PyTorch笔记--交叉熵损失函数实现


交叉熵(cross entropy):用于度量两个概率分布间的差异信息。交叉熵越小,代表这两个分布越接近。

函数表示(这是使用softmax作为激活函数的损失函数表示):

是真实值,是预测值。)

命名说明:

pred=F.softmax(logits),logits是softmax函数的输入,pred代表预测值,是softmax函数的输出。

pred_log=F.log_softmax(logits),pred_log代表对预测值再取对数后的结果。也就是将logits作为log_softmax()函数的输入。

方法一,使用log_softmax()+nll_loss()实现

torch.nn.functional.log_softmax(input)

  对输入使用softmax函数计算,再取对数。

torch.nn.functional.nll_loss(input, target)

  input是经log_softmax()函数处理后的结果,pred_log

  target代表的是真实值。

  有了这两个输入后,该函数对其实现交叉熵损失函数的计算,即上面公式中的L。

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> pred_log = F.log_softmax(logits, dim=1)
>>> pred_log
tensor([[ -0.8779,  -6.7271,  -9.1801,  -6.8515,  -9.6900,  -6.3061,  -3.7304,
          -8.1933, -11.5704,  -0.5873]])
>>> F.nll_loss(pred_log, torch.tensor([3]))
tensor(6.8515)

logits的维度是(1, 10)这里可以理解成是1个输入,最终可能得到10个分类的结果中的一个。pred_log就是

这里的参数target=torch.tensor([3]),我的理解是,他代表真正的分类的值是在第4类(从0编号)。

使用独热编码代表真实值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即这个输入它是属于第4类的。

根据上述公式进行计算,现在我们 都已经知道了。

对其进行点乘操作

 

 

 

 方法二,使用cross_entropy()实现

torch.nn.functional.cross_entropy(input, target)

  这里的input是没有经过处理的logits,这个函数会自动根据logits计算出pred_log

  target是真实值

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> F.cross_entropy(logits, torch.tensor([3]))
tensor(6.8515)

这里我删除了上面使用方法一的代码部分,x和w没有重新随机生成,所以计算结果是一样的。

 

对于分类任务,交叉熵相对均方误差效果更好。


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM