關於pytorch語義分割二分類問題的兩種做法


形式1:輸出為單通道

分析

即網絡的輸出 output 為 [batch_size, 1, height, width] 形狀。其中 batch_szie 為批量大小,1 表示輸出一個通道,heightwidth 與輸入圖像的高和寬保持一致。

在訓練時,輸出通道數是 1,網絡得到的 output 包含的數值是任意的數。給定的 target ,是一個單通道標簽圖,數值只有 0 和 1 這兩種。為了讓網絡輸出 output 不斷逼近這個標簽,首先會讓 output 經過一個sigmoid 函數,使其數值歸一化到[0, 1],得到 output1 ,然后讓這個 output1target 進行交叉熵計算,得到損失值,反向傳播更新網絡權重。最終,網絡經過學習,會使得 output1 逼近target

訓練結束后,網絡已經具備讓輸出的 output 經過轉換從而逼近 target 的能力。首先將輸出的 output 通過sigmoid 函數,然后取一個閾值(一般設置為0.5),大於閾值則取1反之則取0,從而得到預測圖 predict。后續則是一些評估相關的計算。

代碼實現

在這個過程中,訓練的損失函數為二進制交叉熵損失函數,然后根據輸出是否用到了sigmoid有兩種可選的pytorch實現方式:

output = net(input)  # net的最后一層沒有使用sigmoid
loss_func1 = torch.nn.BCEWithLogitsLoss()
loss = loss_func1(output, target)

當網絡最后一層沒有使用sigmoid時,需要使用 torch.nn.BCEWithLogitsLoss() ,顧名思義,在這個函數中,拿到output首先會做一個sigmoid操作,再進行二進制交叉熵計算。上面的操作等價於

output = net(input)  # net的最后一層沒有使用sigmoid
output = F.sigmoid(output)
loss_func1 = torch.nn.BCEWithLoss()
loss = loss_func1(output, target)

當然,你也可以在網絡最后一層加上sigmoid操作。從而省去第二行的代碼(在預測時也可以省去)。

在預測試時,可用下面的代碼實現預測圖的生成

output = net(input)  # net的最后一層沒有使用sigmoid
output = F.sigmoid(output)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
...

即大於0.5的記為1,小於0.5記為0。

形式2:輸出為多通道

分析

即網絡的輸出 output 為 [batch_size, num_class, height, width] 形狀。其中 batch_szie 為批量大小,num_class 表示輸出的通道數與分類數量一致,heightwidth 與輸入圖像的高和寬保持一致。

在訓練時,輸出通道數是 num_class(這里取2),網絡得到的 output 包含的數值是任意的數。給定的 target ,是一個單通道標簽圖,數值只有 0 和 1 這兩種。為了讓網絡輸出 output 不斷逼近這個標簽,首先會讓 output 經過一個 softmax 函數,使其數值歸一化到[0, 1],得到 output1 ,在各通道中,這個數值加起來會等於1。對於target 他是一個單通道圖,首先使用onehot編碼,轉換成 num_class個通道的圖像,每個通道中的取值是根據單通道中的取值計算出來的,例如單通道中的第一個像素取值為1(0<= 1 <=num_class-1,這里num_class=2),那么onehot編碼后,在第一個像素的位置上,兩個通道的取值分別為0,1。也就是說像素的取值決定了對應序號的通道取1,其他的通道取0,這個非常關鍵。上面的操作執行完后得到target1,讓這個 output1target1 進行交叉熵計算,得到損失值,反向傳播更新網路權重。最終,網絡經過學習,會使得 output1 逼近target1(在各通道層面上)。

訓練結束后,網絡已經具備讓輸出的 output 經過轉換從而逼近 target 的能力。計算 output 中各通道每一個像素位置上,取值最大的那個對應的通道序號,從而得到預測圖 predict。后續則是一些評估相關的計算。

代碼實現

在這個過程中,則可以使用交叉熵損失函數:

output = net(input)  # net的最后一層沒有使用sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(output, target)

根據前面的分析,我們知道,正常的output是 [batch_size, num_class, height, width]形狀的,而target是[batch_size, height, width]形狀的,需要按照上面的分析進行轉換才可以計算交叉熵,而在pytorch中,我們不需要進一步做這個處理,直接使用就可以了。

在預測試時,使用下面的代碼實現預測圖的生成

output = net(input)  # net的最后一層沒有使用sigmoid
predict = output.argmax(dim=1)
...

即得到輸出后,在通道方向上找出最大值所在的索引號。

小結

總的來說,我覺得第二種方式更值得推廣,一方面不用考慮閾值的選取問題;另一方面,該方法同樣適用於多類別的語義分割任務,通用性更強。

參考資料

[1]https://blog.csdn.net/longshaonihaoa/article/details/105253553

[2]https://cuijiahua.com/blog/2020/03/dl-16.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM