pytorch 分割二分類的兩種形式


1、單通道輸出

在訓練時,輸出通道為1,網絡的輸出數值是任意的。標簽是單通道的二值圖,對輸出使用sigmoid,使其數值歸一化到[0,1],然后和標簽做交叉熵損失。

訓練結束后,將輸出的output經過sigmoid函數,然后取閾值(一般為0.5),大於閾值則為1否則取0,從而得到最終的預測結果。

 

代碼實現:

#第一種
output = net(input)  # net的最后一層沒有使用sigmoid
Loss = torch.nn.BCEWithLogitsLoss()#會先做sigmoid然后求交叉熵
loss = Loss(output, target)

#第二種
output = net(input)  # net的最后一層沒有使用sigmoid
output = F.sigmoid(output)
Loss = torch.nn.BCEWithLoss()
loss = Loss(output, target)

#預測
output = net(input)  # net的最后一層沒有使用sigmoid
output = F.sigmoid(output)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output)

  

2、二(多)通道輸出

在訓練時,輸出通道為2,網絡的輸出數值是任意的。讓網絡的輸出經過softmax,歸一化到[0,1],在各通道中,同一位置加起來的數值會等於1。標簽是單通道的二值圖,首先使用one-hot編碼,使其變為二通道,當前通道值為1,另一通道上就為0。然后將輸出和標簽做交叉熵損失。

訓練結束后,取每個像素位置上對應最大值的通道序號為最終的預測值,從而得到最終的預測結果。

 

代碼實現:

#訓練
output = net(input)  # net的最后一層沒有使用sigmoid
Loss = torch.nn.CrossEntropyLoss()
loss = Loss(output, target)

#預測
output = net(input)  # net的最后一層沒有使用sigmoid
predict = output.argmax(dim=1) 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM