n = 5 #類別數 indices = torch.randint(0, n, size=(15,15)) #生成數組元素0~5的二維數組(15*15) one_hot = torch.nn.functional.one_hot(indices, n) #size=(15, 15, n)
1. One-hot編碼(一維數組、二維圖像都可以):label = torch.nn.functional.one_hot(label, N)。 #一維數組的one hot編碼,N為類別,label為數組
ps. (1)把數組(m,n)轉換成(a,b,c),reshape/view時是將前者逐行讀取,轉換成后者的。
(2)還會補充one-hot編碼轉換成單通道圖像的方法。
2. One-hot編碼---label
對於一維數組,results = one_hot_label.argmax(dim=1, keepdim=True)
或者a = [np.argmax(l) for l in one_hot] #將onehot編碼轉成一般編碼