- NLLLoss 和 CrossEntropyLoss
在圖片單標簽分類時,輸入m張圖片,輸出一個m*N的Tensor,其中N是分類個數。比如輸入3張圖片,分3類,最后的輸出是一個3*3的Tensor
input = torch.tensor([[-0.1123, -0.6028, -0.0450],
[ 0.1596, 0.2215, -1.0176],
[-0.2359, -0.7898, 0.7097]])
第123行分別是第123張圖片的結果,假設第123列分別是貓、狗和豬的分類得分。
first step: 對每一行使用Softmax,這樣可以得到每張圖片的概率分布。概率最大的為:1:豬;2:狗;3:豬。
sm = torch.nn.Softmax(dim=1)
sm(input)
tensor([[0.3729, 0.2283, 0.3988],
[0.4216, 0.4485, 0.1299],
[0.2410, 0.1385, 0.6205]])
second step: 對softmax結果取對數
torch.log(sm(input))
tensor([[-0.9865, -1.4770, -0.9192],
[-0.8637, -0.8019, -2.0409],
[-1.4229, -1.9767, -0.4773]])
Softmax后的數值都在0~1之間,所以log之后值域是負無窮到0。
NLLLoss的結果就是把上面的輸出與Label對應的那個值拿出來,再去掉負號,再求均值。
假設我們現在Target是[0,2,1](第一張圖片是貓,第二張是豬,第三張是狗)。第一行取第0個元素,第二行取第2個,第三行取第1個,去掉負號,結果是:[0.9865,2.0409,1.9767]。再求個均值,結果是:1.66
對比NLLLoss的結果
loss = torch.nn.NLLLoss()
loss(torch.log(sm(input)),target)
# 1.6681
CrossEntropyLoss 相當於上述步驟的組合,Softmax–Log–NLLLoss合並成一步
loss2 = torch.nn.CrossEntropyLoss()
loss2(input,target)
# 1.6681