one-hot 编码



def onehot(labels):
  '''one-hot 编码'''
  #数据有几行输出
  n_sample = len(labels)
  #数据分为几类。因为编码从0开始所以要加1
  n_class = max(labels) + 1
  #建立一个batch所需要的数组,全部赋0.
  onehot_labels = np.zeros((n_sample, n_class))
  #对每一行的,对应分类赋1
  onehot_labels[np.arange(n_sample), labels] = 1
  return onehot_labels

运行结果:

label=np.array([0,1,2])

onehot(label)
Out[8]:
array([[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]])

label=np.array([0,4,7,1,1,1,4,1])

onehot(label)
Out[10]:
array([[ 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 1.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0.]])

总结:本次标签只有一类,如第一个标签为一类,有两种情况。第二个为标签一类,有七种情况。如果标签为两类,比如{男生,女生}、{一年级、二年级、三年级},那编码的长度为5.

onehot标签则是顾名思义,一个长度为n的数组,只有一个元素是1.0,其他元素是0.0。

想想为什么要这样编码,知乎大佬的的一个解释感觉极有道理。

使用onehot的直接原因是现在多分类cnn网络的输出通常是softmax层,而它的输出是一个概率分布,从而要求输入的标签也以概率分布的形式出现,进而算交叉熵之类。
onehot其实就是给出了,真是的样本真实概率分布,其中一个样本数据概率为1,其他全为0.。计算损失交叉熵时,直接用1*log(1/概率),就直接算出了交叉熵,作为损失。



 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM