def onehot(labels):
'''one-hot 編碼'''
#數據有幾行輸出
n_sample = len(labels)
#數據分為幾類。因為編碼從0開始所以要加1
n_class = max(labels) + 1
#建立一個batch所需要的數組,全部賦0.
onehot_labels = np.zeros((n_sample, n_class))
#對每一行的,對應分類賦1
onehot_labels[np.arange(n_sample), labels] = 1
return onehot_labels
運行結果:
label=np.array([0,1,2])
onehot(label)
Out[8]:
array([[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]])
label=np.array([0,4,7,1,1,1,4,1])
onehot(label)
Out[10]:
array([[ 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 1.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.]])
總結:本次標簽只有一類,如第一個標簽為一類,有兩種情況。第二個為標簽一類,有七種情況。如果標簽為兩類,比如{男生,女生}、{一年級、二年級、三年級},那編碼的長度為5.
onehot標簽則是顧名思義,一個長度為n的數組,只有一個元素是1.0,其他元素是0.0。
想想為什么要這樣編碼,知乎大佬的的一個解釋感覺極有道理。
使用onehot的直接原因是現在多分類cnn網絡的輸出通常是softmax層,而它的輸出是一個概率分布,從而要求輸入的標簽也以概率分布的形式出現,進而算交叉熵之類。