PyTorch筆記--交叉熵損失函數實現


交叉熵(cross entropy):用於度量兩個概率分布間的差異信息。交叉熵越小,代表這兩個分布越接近。

函數表示(這是使用softmax作為激活函數的損失函數表示):

是真實值,是預測值。)

命名說明:

pred=F.softmax(logits),logits是softmax函數的輸入,pred代表預測值,是softmax函數的輸出。

pred_log=F.log_softmax(logits),pred_log代表對預測值再取對數后的結果。也就是將logits作為log_softmax()函數的輸入。

方法一,使用log_softmax()+nll_loss()實現

torch.nn.functional.log_softmax(input)

  對輸入使用softmax函數計算,再取對數。

torch.nn.functional.nll_loss(input, target)

  input是經log_softmax()函數處理后的結果,pred_log

  target代表的是真實值。

  有了這兩個輸入后,該函數對其實現交叉熵損失函數的計算,即上面公式中的L。

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> pred_log = F.log_softmax(logits, dim=1)
>>> pred_log
tensor([[ -0.8779,  -6.7271,  -9.1801,  -6.8515,  -9.6900,  -6.3061,  -3.7304,
          -8.1933, -11.5704,  -0.5873]])
>>> F.nll_loss(pred_log, torch.tensor([3]))
tensor(6.8515)

logits的維度是(1, 10)這里可以理解成是1個輸入,最終可能得到10個分類的結果中的一個。pred_log就是

這里的參數target=torch.tensor([3]),我的理解是,他代表真正的分類的值是在第4類(從0編號)。

使用獨熱編碼代表真實值是[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],即這個輸入它是屬於第4類的。

根據上述公式進行計算,現在我們 都已經知道了。

對其進行點乘操作

 

 

 

 方法二,使用cross_entropy()實現

torch.nn.functional.cross_entropy(input, target)

  這里的input是沒有經過處理的logits,這個函數會自動根據logits計算出pred_log

  target是真實值

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 28)
>>> w = torch.randn(10,28)
>>> logits = x @ w.t()
>>> F.cross_entropy(logits, torch.tensor([3]))
tensor(6.8515)

這里我刪除了上面使用方法一的代碼部分,x和w沒有重新隨機生成,所以計算結果是一樣的。

 

對於分類任務,交叉熵相對均方誤差效果更好。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM