pytorch 分類問題用到的分類器(F.CROSS_ENTROPY和F.BINARY_CROSS_ENTROPY_WITH_LOGITS)


推薦參考:https://www.freesion.com/article/4488859249/

實際運用時注意:

F.binary_cross_entropy_with_logits()對應的類是torch.nn.BCEWithLogitsLoss,在使用時會自動添加sigmoid,然后計算loss。(其實就是nn.sigmoid和nn.BCELoss的合體)

total = model(xi, xv)  # 回到forward函數 , 返回 100*1維
loss = criterion(total, y)  # y是label,整型 0或1
preds = (F.sigmoid(total) > 0.5)  # 配合sigmoid使用
train_num_correct += (preds == y).sum()

 

想一想,其實交叉熵就是-sum(y_true * log(y_pred)),鏈接中的公式中,由於只有y_true等於1時計算才有效,所以可以化簡,同時y_pred經過了softmax處理


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM