one-hot 編碼



def onehot(labels):
  '''one-hot 編碼'''
  #數據有幾行輸出
  n_sample = len(labels)
  #數據分為幾類。因為編碼從0開始所以要加1
  n_class = max(labels) + 1
  #建立一個batch所需要的數組,全部賦0.
  onehot_labels = np.zeros((n_sample, n_class))
  #對每一行的,對應分類賦1
  onehot_labels[np.arange(n_sample), labels] = 1
  return onehot_labels

運行結果:

label=np.array([0,1,2])

onehot(label)
Out[8]:
array([[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]])

label=np.array([0,4,7,1,1,1,4,1])

onehot(label)
Out[10]:
array([[ 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 1.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.]])

總結:本次標簽只有一類,如第一個標簽為一類,有兩種情況。第二個為標簽一類,有七種情況。如果標簽為兩類,比如{男生,女生}、{一年級、二年級、三年級},那編碼的長度為5.

onehot標簽則是顧名思義,一個長度為n的數組,只有一個元素是1.0,其他元素是0.0。

想想為什么要這樣編碼,知乎大佬的的一個解釋感覺極有道理。

使用onehot的直接原因是現在多分類cnn網絡的輸出通常是softmax層,而它的輸出是一個概率分布,從而要求輸入的標簽也以概率分布的形式出現,進而算交叉熵之類。
onehot其實就是給出了,真是的樣本真實概率分布,其中一個樣本數據概率為1,其他全為0.。計算損失交叉熵時,直接用1*log(1/概率),就直接算出了交叉熵,作為損失。



 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM