一、pytorch中各損失函數的比較
Pytorch中Softmax、Log_Softmax、NLLLoss以及CrossEntropyLoss的關系與區別詳解
Pytorch詳解BCELoss和BCEWithLogitsLoss
總結這兩篇博客的內容就是:
- CrossEntropyLoss函數包含Softmax層、log和NLLLoss層,適用於單標簽任務,主要用在單標簽多分類任務上,當然也可以用在單標簽二分類上。
- BCEWithLogitsLoss函數包括了Sigmoid層和BCELoss層,適用於二分類任務,可以是單標簽二分類,也可以是多標簽二分類任務。
- 以上這幾個損失函數本質上都是交叉熵損失函數,只不過是適用范圍不同而已。
第一條的原因是:
也就是說,各個class的得分是互斥的,這個class得分多了,另個class的得分會減少。
第二條的原因是:
也就是說,各個class的得分是獨立的,互不影響,所以可以進行多標簽預測。
二、程序示例
在使用中,最常遇到的情況是,CrossEntropyLoss的input是一個二維張量,target是一維張量,例如:
loss = nn.CrossEntropyLoss() input = torch.randn(3, 5, requires_grad=True) # 3個樣本,5個類別 target = torch.empty(3, dtype=torch.long).random_(5) # torch.long表示長整型,torch.empty(3)表示產生一維向量,長度為3,元素內容為空。 # random_(5)表示用0到4的整數去填充3個空元素。之所以是整數,是因為前面規定了torch.long。 output = loss(input, target) output.backward()
CrossEntropyLoss的計算公式為(本質上是交叉熵公式+softmax公式):
BCEWithLogitsLoss和BCELoss的input和target必須保持維度相同,即同時是一維張量,或者同時是二維張量,例如:
m = nn.Sigmoid() loss = nn.BCELoss() # input和target同為一維張量 input = torch.randn(3, requires_grad=True) target = torch.empty(3).random_(2) # 單標簽二分類任務 output = loss(m(input), target) output.backward() # input和target同為二維張量 input = torch.randn([5, 3], requires_grad=True) target = torch.empty([5, 3]).random_(2) # 多標簽二分類任務 output = loss(m(input), target) output.backward()
-------------------------------------------
loss = nn.BCEWithLogitsLoss() # input和target同為一維張量 input = torch.randn(3, requires_grad=True) target = torch.empty(3).random_(2) # 單標簽二分類任務 output = loss(input, target) output.backward() # input和target同為二維張量 input = torch.randn([5,3], requires_grad=True) target = torch.empty([5,3]).random_(2) # 多標簽二分類任務 output = loss(input, target) output.backward()
三、交叉熵損失函數的推導
以下的內容摘自知乎:交叉熵、相對熵(KL散度)、JS散度和Wasserstein距離(推土機距離)
對於二分類問題,假設是貓和狗的分類問題,則p(x=貓)=1-p(x=狗),同樣地q(x=貓)=1-q(x=狗),所以,對於某一張圖片(樣本),它的損失可通過如下公式計算:
這個二分類公式其實是cross entropy between two Bernoulli distribution。這個公式不僅可以用於單標簽的二分類問題,也可以用於多標簽的二分類問題。在pytorch的BCEWithLogitsLoss函數或者BCELoss函數中,實際計算公式是這樣的:
式中,n是指總的類別數目,這個公式指的是單個樣本的損失。對單標簽二分類時,即當n=2時,(2)式和(1)式等價,證明:
簡單的算例證明可以參考知乎:pytorch中的損失函數總結 第6小節