總結: NLLLoss, CrossEntropyLoss, BCELoss, BCEWithLogitsLoss比較,以及交叉熵損失函數推導


一、pytorch中各損失函數的比較

Pytorch中Softmax、Log_Softmax、NLLLoss以及CrossEntropyLoss的關系與區別詳解

Pytorch詳解BCELoss和BCEWithLogitsLoss

 

總結這兩篇博客的內容就是:

  1. CrossEntropyLoss函數包含Softmax層、log和NLLLoss層,適用於單標簽任務,主要用在單標簽多分類任務上,當然也可以用在單標簽二分類上。
  2. BCEWithLogitsLoss函數包括了Sigmoid層和BCELoss層,適用於二分類任務,可以是單標簽二分類,也可以是多標簽二分類任務。
  • 以上這幾個損失函數本質上都是交叉熵損失函數,只不過是適用范圍不同而已。

第一條的原因是:

也就是說,各個class的得分是互斥的,這個class得分多了,另個class的得分會減少。

第二條的原因是:

也就是說,各個class的得分是獨立的,互不影響,所以可以進行多標簽預測。

 

二、程序示例

在使用中,最常遇到的情況是,CrossEntropyLoss的input是一個二維張量,target是一維張量,例如:

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)   # 3個樣本,5個類別
target = torch.empty(3, dtype=torch.long).random_(5)   
# torch.long表示長整型,torch.empty(3)表示產生一維向量,長度為3,元素內容為空。
# random_(5)表示用0到4的整數去填充3個空元素。之所以是整數,是因為前面規定了torch.long。

output = loss(input, target)
output.backward()

CrossEntropyLoss的計算公式為(本質上是交叉熵公式+softmax公式):

BCEWithLogitsLoss和BCELoss的input和target必須保持維度相同,即同時是一維張量,或者同時是二維張量,例如:

m = nn.Sigmoid()
loss = nn.BCELoss()

# input和target同為一維張量
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)   # 單標簽二分類任務
output = loss(m(input), target)
output.backward()

# input和target同為二維張量
input = torch.randn([5, 3], requires_grad=True)
target = torch.empty([5, 3]).random_(2)   # 多標簽二分類任務
output = loss(m(input), target)
output.backward()

-------------------------------------------

loss = nn.BCEWithLogitsLoss()

# input和target同為一維張量
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)   # 單標簽二分類任務
output = loss(input, target)
output.backward()

# input和target同為二維張量
input = torch.randn([5,3], requires_grad=True)
target = torch.empty([5,3]).random_(2)   # 多標簽二分類任務
output = loss(input, target)
output.backward()

三、交叉熵損失函數的推導

以下的內容摘自知乎:交叉熵、相對熵(KL散度)、JS散度和Wasserstein距離(推土機距離)

對於二分類問題,假設是貓和狗的分類問題,則p(x=貓)=1-p(x=狗),同樣地q(x=貓)=1-q(x=狗),所以,對於某一張圖片(樣本),它的損失可通過如下公式計算:

這個二分類公式其實是cross entropy between two Bernoulli distribution。這個公式不僅可以用於單標簽的二分類問題,也可以用於多標簽的二分類問題。在pytorch的BCEWithLogitsLoss函數或者BCELoss函數中,實際計算公式是這樣的:

式中,n是指總的類別數目,這個公式指的是單個樣本的損失。對單標簽二分類時,即當n=2時,(2)式和(1)式等價,證明:

簡單的算例證明可以參考知乎:pytorch中的損失函數總結 第6小節


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM