Common Loss Functions in PyTorch


Official documentation: https://pytorch.org/docs/stable/nn.html#loss-functions

1: torch.nn.L1Loss

Creates a criterion that measures the mean absolute error (MAE) between each element in the input x and the target y.

MAE is the mean absolute error, also known as the L1 loss:

$$\ell(x, y) = \frac{1}{N}\sum_{n=1}^{N} \left| x_n - y_n \right|$$

import torch
import torch.nn as nn

loss = nn.L1Loss()
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(input, target)
print(output)
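As a sanity check, the default mean reduction can be verified by hand; below is a minimal sketch with hand-picked values (the tensors here are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

# Hand-picked values so the expected MAE is easy to check by eye.
input = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([2.0, 2.0, 5.0])

loss = nn.L1Loss()                      # default reduction='mean'
manual = (input - target).abs().mean()  # (1 + 0 + 2) / 3 = 1.0

print(loss(input, target).item())  # 1.0
print(manual.item())               # 1.0
```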

 

2: torch.nn.MSELoss

Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input x and the target y.
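Written out (with the default reduction='mean'), the criterion computes:

$$\ell(x, y) = \frac{1}{N}\sum_{n=1}^{N} \left( x_n - y_n \right)^2$$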

import torch
import torch.nn as nn

loss = nn.MSELoss()
input = torch.randn(1, 2, requires_grad=True)
target = torch.randn(1, 2)
output = loss(input, target)
print(output)

 

3: torch.nn.NLLLoss && torch.nn.CrossEntropyLoss

torch.nn.NLLLoss is the negative log likelihood loss, used for multi-class classification.

torch.nn.CrossEntropyLoss is the cross-entropy loss function; it combines nn.LogSoftmax and nn.NLLLoss in a single class.

The difference between the two: NLLLoss expects log-probabilities (e.g. the output of LogSoftmax), while CrossEntropyLoss takes raw scores (logits) and applies LogSoftmax internally:

import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(m(input), target)
print(output)

loss = nn.CrossEntropyLoss()
output = loss(input, target)  # same value as NLLLoss applied to the LogSoftmax output
print(output)
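Since CrossEntropyLoss is LogSoftmax followed by NLLLoss, the two snippets above produce the same value. A quick check (seed and target values here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input = torch.randn(3, 5)          # N x C logits
target = torch.tensor([1, 0, 4])   # class indices in [0, C)

nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(input), target)
ce = nn.CrossEntropyLoss()(input, target)

print(torch.allclose(nll, ce))  # True
```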

 

 

4: torch.nn.BCELoss && torch.nn.BCEWithLogitsLoss

A criterion that measures the binary cross entropy between the target and the output:

$$\ell(x, y) = -\frac{1}{N}\sum_{n=1}^{N} \left[ y_n \log x_n + (1 - y_n) \log(1 - x_n) \right]$$

where N is the batch size, x_n is the output, and y_n is the target.

The difference between the two: BCELoss expects probabilities (e.g. the output of a Sigmoid), while BCEWithLogitsLoss takes raw logits and applies the sigmoid internally, in a numerically more stable way:

import torch

m = torch.nn.Sigmoid()
loss = torch.nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
print(output)

loss = torch.nn.BCEWithLogitsLoss()
output = loss(input, target)  # same value, computed from the raw logits
print(output)
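Likewise, BCEWithLogitsLoss applied to raw logits matches Sigmoid followed by BCELoss, up to floating-point tolerance (the seed here is arbitrary):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(3)
target = torch.empty(3).random_(2)  # random 0/1 labels

bce = torch.nn.BCELoss()(torch.sigmoid(logits), target)
bce_logits = torch.nn.BCEWithLogitsLoss()(logits, target)

print(torch.allclose(bce, bce_logits))  # True
```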

 

ref:https://www.cnblogs.com/wanghui-garcia/p/10862733.html

