Pytorch技巧


one-hot encoding和常規label的轉化

常規label指0,1,2,3,4,5,......(一個數代表一類)

#常規label轉one-hot向量
def encode_onehot(labels):            #用單位矩陣來構建onehot向量
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in        #單位矩陣
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot
#one-hot向量轉常規label
labels = torch.LongTensor(np.where(labels)[1])

增減layer:

參考 https://www.cnblogs.com/marsggbo/p/8781774.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM