語義分割單通道和多通道輸出交叉熵損失函數的計算問題


摘要

本文驗證了語義分割任務下,單通道輸出和多通道輸出時,使用交叉熵計算損失值的細節問題。對比驗證了使用簡單的函數和自帶損失函數的結果,通過驗證,進一步加強了對交叉熵的理解。

交叉熵損失函數

交叉熵損失函數的原理和推導過程,可以參考這篇博文,交叉熵的計算公式如下:

\[CE(p,q) = -p*log(q) \]

其中 \(q\) 為預測的概率,\(q∈[0,1]\)\(p\) 為標簽,\(p∈\{0,1\}\)

而交叉熵損失函數則是利用上式計算每一個分類的交叉熵之和。對於概率,所有分類的概率 \(q\) 之和滿足相加等於1,而對於標簽,則需要進行one-hot編碼,使得有且只有一個分類的 \(p\) 為1,其余的分類為0。

單通道輸出時的交叉熵損失計算

單通道輸出交叉熵損失計算示意圖

首先,假設我們研究的是一個二分類語義分割問題。

網絡的輸入是一個 2×2 的圖像,設置 batch_size 為 2,網絡輸出單通道特征圖。網絡的標簽也是一個 2 ×2 的二進制掩模圖(即只有0和1的單通道圖像)。

我們在 pytorch 中將其定義:

import torch

# 假設輸出一個 [batch_size=2, channel=1, height=2, width=2] 格式的張量 x1
x1 = torch.tensor(
    [[[[ 0.43, -0.25],
        [-0.32, 0.69]]],

        [[[-0.29, 0.37],
          [0.54,  -0.72]]]])

# 假設標簽圖像為與 x1 同型的張量 y1
y1 = torch.tensor(
    [[[[0., 0.],
        [0., 1.]]],

        [[[0., 0.],
          [1.,  1.]]]])

在進行交叉熵前,首先需要做一個 sigmoid 操作,將數值壓縮到0到1之間:

# 根據二進制交叉熵的計算過程
# 首先進行sigmoid計算,然后與標簽圖像進行二進制交叉熵計算,最后取平均值,即為損失值

# 1. sigmoid
s1 = torch.sigmoid(x1)
s1

'''
out:
tensor([[[[0.6059, 0.4378],
          [0.4207, 0.6660]]],


        [[[0.4280, 0.5915],
          [0.6318, 0.3274]]]]
'''

然后進行交叉熵計算,由於計算的是每個像素的損失值,所以要取個平均值:

# 2.交叉熵計算
loss_cal = -1*(y1*torch.log(s1)+(1-y1)*torch.log(1-s1)) # 此處相當於一個one-hot編碼
loss_cal_mean = torch.mean()
loss_cal_mean

'''
out:
tensor(0.6861)
'''

為了驗證結果,我們使用 pytorch 自帶的二進制交叉熵損失函數計算:

# 使用torch自帶的二進制交叉熵計算
loss_bce = torch.nn.BCELoss()(s1,y1)
loss_bce

'''
out:
tensor(0.6861)
'''

當計算損失值前沒有進行 sigmoid 操作時,pytorch 還提供了包含這個操作的二進制交叉熵損失函數:

# 使用帶sigmoid的二進制交叉熵計算
loss_bce2 = torch.nn.BCEWithLogitsLoss()(x1,y1)
loss_bce2

'''
out:
tensor(0.6861)
'''

可以看到,我們使用了三種方式,計算了交叉熵損失,結果一致。

多通道輸出時的交叉熵損失計算

多通道輸出交叉熵損失計算示意圖

首先,假設我們研究的是一個二分類語義分割問題。

網絡的輸入是一個 2×2 的圖像,設置 batch_size 為 2,網絡輸出多(二)通道特征圖。網絡的標簽也是一個 2 ×2 的二進制掩模圖(即只有0和1的單通道圖像)。

我們在 pytorch 中將其定義:

# 假設輸出一個[batch_size=2, channel=2, height=2, width=2]格式的張量 x1
x1 = torch.tensor([[[[ 0.3164, -0.1922],
          [ 0.4326, -1.2193]],

         [[ 0.6873,  0.6838],
          [ 0.2244,  0.5615]]],


        [[[-0.2516, -0.8875],
          [-0.6289, -0.1796]],

         [[ 0.0411, -1.7851],
          [-0.3069, -1.0379]]]])

# 假設標簽圖像為與x1同型,然后去掉channel的張量 y1 (注意兩點,channel沒了,格式為LongTensor)
y1 = torch.LongTensor([[[0., 1.],
         [1., 0.]],

        [[1., 1.],
         [0., 1.]]])

在進行交叉熵前,首先需要做一個 softmax 操作,將數值壓縮到0到1之間,且使得各通道之間的數值之和為1:

# 1.softmax
s1 = torch.softmax(x1,dim=1)
s1

'''
out:
tensor([[[[0.4083, 0.2940],
          [0.5519, 0.1442]],

         [[0.5917, 0.7060],
          [0.4481, 0.8558]]],


        [[[0.4273, 0.7105],
          [0.4202, 0.7023]],

         [[0.5727, 0.2895],
          [0.5798, 0.2977]]]])
'''

對於標簽圖,由於其張量的形狀與網絡輸出張量不一樣,因此需要做一個one-hot轉換,什么是one-hot?請看這篇博文

# 2.one-hot
y1_one_hot = torch.zeros_like(x1).scatter_(dim=1,index=y1.unsqueeze(dim=1),src=torch.ones_like(x1))
y1_one_hot

'''
out:
tensor([[[[1., 0.],
          [0., 1.]],

         [[0., 1.],
          [1., 0.]]],


        [[[0., 0.],
          [1., 0.]],

         [[1., 1.],
          [0., 1.]]]])
'''

這里需要重點理解這個scatter_函數,他起到的作用十分關鍵,one-hot 轉換時,其實可以理解為將一個同型的全1矩陣中的元素,有選擇性的復制到全0矩陣中的過程,這里的選擇依據就是我們的標簽圖,它決定了哪個位置和通道上的元素取值為 1 。在scatter_ 函數中,dim 決定了用於確定我們在哪個維度上開始定位要建立聯系的元素,index是我們選擇的依據。

按照交叉熵定義,繼續計算:

# 交叉熵計算
loss_cal = -1 *(y1_one_hot * torch.log(s1)) 
loss_cal_mean = loss_cal.sum(dim=1).mean() # 在batch維度下計算每個樣本的交叉熵
loss_cal_mean

'''
out:
tensor(0.9823)
'''

我們也可以使用 pytorch 自帶的交叉熵損失函數計算:

loss_ce = torch.nn.CrossEntropyLoss()(x1,y1)
loss_ce

'''
tensor(0.9823)
'''

可以看到,兩種方式結果一樣。

結論

  • 交叉熵本質上將一群對象擇其一進行研究,自然就變成一個二進制問題,即是這個對象或不是這個對象,然后將標簽與概率融進公式中,計算損失值。對於每一個對象都可以計算一個損失值,求個平均值就是最后這個群體的損失值了。

  • 不論是sigmoid或者softmax,我們都是在有目的將數據規整到0到1之間,從而形成一個概率值,sigmoid針對的是二分類問題,因此算出一個概率,另一個用一減去就到了。多分類問題,由於最后會輸出對應數量的值,softmax 能夠將這些值轉換到0到1,並滿足加起來等於1,這樣的話,當我們只研究其中一個類的概率時,其他類的概率自然就是用1減去這個類的概率了,又回到了二分類問題。

  • 對於二分類語義分割問題,其實采用上述兩種方式都是可以的。

參考資料

[1] pytorch中的 scatter_()函數使用和詳解

[2] pytorch交叉熵使用方法

[3] pytorch損失函數之nn.BCELoss()(為什么用交叉熵作為損失函數)

[4] pytorch損失函數之nn.CrossEntropyLoss()、nn.NLLLoss()

[5] PyTorch中名不符實的損失函數

[6] Pytorch中Softmax、Log_Softmax、NLLLoss以及CrossEntropyLoss的關系與區別詳解

[7] 二分類問題,應該選擇sigmoid還是softmax?


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM