Pytorch nn.BCEWithLogitsLoss() 的簡單理解與用法


這個東西,本質上和nn.BCELoss()沒有區別,只是在BCELoss上加了個logits函數(也就是sigmoid函數),例子如下:

import torch import torch.nn as nn label = torch.Tensor([1, 1, 0]) pred = torch.Tensor([3, 2, 1]) pred_sig = torch.sigmoid(pred) loss = nn.BCELoss() print(loss(pred_sig, label)) loss = nn.BCEWithLogitsLoss() print(loss(pred, label)) loss = nn.BCEWithLogitsLoss() print(loss(pred_sig, label))

輸出結果分別為:

tensor(0.4963) tensor(0.4963) tensor(0.5990)

可以看到,nn.BCEWithLogitsLoss()相當於是在nn.BCELoss()中預測結果pred的基礎上先做了個sigmoid,然后繼續正常算loss。所以這就涉及到一個比較奇葩的bug,如果網絡本身在輸出結果的時候已經用sigmoid去處理了,算loss的時候用nn.BCEWithLogitsLoss()…那么就會相當於預測結果算了兩次sigmoid,可能會出現各種奇奇怪怪的問題——

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM