1、單通道輸出
在訓練時,輸出通道為1,網絡的輸出數值是任意的。標簽是單通道的二值圖,對輸出使用sigmoid,使其數值歸一化到[0,1],然后和標簽做交叉熵損失。
訓練結束后,將輸出的output經過sigmoid函數,然后取閾值(一般為0.5),大於閾值則為1否則取0,從而得到最終的預測結果。
代碼實現:
#第一種 output = net(input) # net的最后一層沒有使用sigmoid Loss = torch.nn.BCEWithLogitsLoss()#會先做sigmoid然后求交叉熵 loss = Loss(output, target) #第二種 output = net(input) # net的最后一層沒有使用sigmoid output = F.sigmoid(output) Loss = torch.nn.BCEWithLoss() loss = Loss(output, target) #預測 output = net(input) # net的最后一層沒有使用sigmoid output = F.sigmoid(output) predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output)
2、二(多)通道輸出
在訓練時,輸出通道為2,網絡的輸出數值是任意的。讓網絡的輸出經過softmax,歸一化到[0,1],在各通道中,同一位置加起來的數值會等於1。標簽是單通道的二值圖,首先使用one-hot編碼,使其變為二通道,當前通道值為1,另一通道上就為0。然后將輸出和標簽做交叉熵損失。
訓練結束后,取每個像素位置上對應最大值的通道序號為最終的預測值,從而得到最終的預測結果。
代碼實現:
#訓練 output = net(input) # net的最后一層沒有使用sigmoid Loss = torch.nn.CrossEntropyLoss() loss = Loss(output, target) #預測 output = net(input) # net的最后一層沒有使用sigmoid predict = output.argmax(dim=1)