For the detailed theory, see https://www.cnblogs.com/wanghui-garcia/p/10862733.html
For BCELoss() and BCEWithLogitsLoss(), the logits and the target labels (which must be one-hot float tensors) have the same shape.
For CrossEntropyLoss(), the target labels have shape [3] in the example below (class indices, not one-hot), while the logits have shape [3, 2]. In general, for multi-class classification the labels have shape [batch], with values between 0 and num_classes-1.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

m = nn.Sigmoid()

loss_f1 = nn.BCELoss()
loss_f2 = nn.BCEWithLogitsLoss()
loss_f3 = nn.CrossEntropyLoss()

logits = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0]])

print(loss_f1(m(logits), labels))  # tensor(0.9314); note the logits go through the sigmoid first
print(loss_f2(logits, labels))     # tensor(0.9314)

label2 = torch.LongTensor([1, 0, 0])
print(loss_f3(logits, label2))     # tensor(1.2842)

logits3 = torch.randn(3, 10)       # a 10-class example
label3 = torch.LongTensor([9, 2, 5])
print(loss_f3(logits3, label3))    # tensor(2.6467)
```
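As an extra sanity check (not from the original post), the equality of the first two printed values can be verified against the binary cross-entropy formula written out by hand; the seed below is arbitrary, used only to make the run deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
logits = torch.randn(3, 2)
labels = torch.tensor([[0., 1.], [1., 0.], [1., 0.]])

# Built-in loss: the sigmoid is fused into the loss for numerical stability.
builtin = nn.BCEWithLogitsLoss()(logits, labels)

# Manual binary cross-entropy: mean of -[y*log(p) + (1-y)*log(1-p)],
# where p = sigmoid(logits).
p = torch.sigmoid(logits)
manual = -(labels * p.log() + (1 - labels) * (1 - p).log()).mean()

print(builtin, manual)  # the two values agree
assert torch.allclose(builtin, manual, atol=1e-6)
```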
If you want to turn label2 into the same form as labels and compute the loss with BCELoss, first convert it to one-hot encoding:
```python
encode = F.one_hot(label2, num_classes=2)  # same values as labels, but the dtype is LongTensor
print(loss_f1(m(logits), encode.type(torch.float32)))  # tensor(0.9314)
```
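Relatedly, a small sketch (same shapes as above, with an arbitrary seed for determinism) showing that CrossEntropyLoss() on integer labels is just log_softmax composed with NLLLoss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
logits = torch.randn(3, 2)
label2 = torch.LongTensor([1, 0, 0])  # class indices, shape [3]

ce = nn.CrossEntropyLoss()(logits, label2)
# Equivalent two-step form: log-softmax over the class dimension, then NLLLoss.
nll = F.nll_loss(F.log_softmax(logits, dim=1), label2)

print(ce, nll)  # the two values agree
assert torch.allclose(ce, nll)
```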
Tip: BCEWithLogitsLoss() can be used for multi-label classification. Apply a sigmoid to each output node of the final classification layer, then compute the cross-entropy loss between each output node and its corresponding label.
```python
import torch
import numpy as np

pred = np.array([[-0.4089, -1.2471, 0.5907],
                 [-0.4897, -0.8267, -0.7349],
                 [0.5241, -0.1246, -0.4751]])
label = np.array([[0, 1, 1],
                  [0, 0, 1],
                  [1, 0, 1]])

pred = torch.from_numpy(pred).float()
label = torch.from_numpy(label).float()

criterion1 = torch.nn.BCEWithLogitsLoss()
loss1 = criterion1(pred, label)
print(loss1)  # tensor(0.7193)

criterion2 = torch.nn.MultiLabelSoftMarginLoss()
loss2 = criterion2(pred, label)
print(loss2)  # tensor(0.7193)
```
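The agreement between the two criteria is no accident: with mean reduction, both reduce to the same element-wise binary cross-entropy. A minimal manual check (the tensors are rebuilt here so the snippet stands alone):

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[-0.4089, -1.2471, 0.5907],
                     [-0.4897, -0.8267, -0.7349],
                     [0.5241, -0.1246, -0.4751]])
label = torch.tensor([[0., 1., 1.],
                      [0., 0., 1.],
                      [1., 0., 1.]])

# Element-wise binary cross-entropy written with log-sigmoid:
# -[y * log(sigmoid(x)) + (1 - y) * log(sigmoid(-x))], averaged over all elements.
manual = -(label * F.logsigmoid(pred)
           + (1 - label) * F.logsigmoid(-pred)).mean()
print(manual)  # matches tensor(0.7193) from both criteria above
assert torch.allclose(manual, torch.nn.BCEWithLogitsLoss()(pred, label))
```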