one-hot encoding和常規label的轉化
常規label指0,1,2,3,4,5,......(一個數代表一類)
#常規label轉one-hot向量 def encode_onehot(labels): #用單位矩陣來構建onehot向量 classes = set(labels) classes_dict = {c: np.identity(len(classes))[i, :] for i, c in #單位矩陣 enumerate(classes)} labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32) return labels_onehot
#one-hot向量轉常規label labels = torch.LongTensor(np.where(labels)[1])
增減layer:
參考 https://www.cnblogs.com/marsggbo/p/8781774.html