焦點損失函數 Focal Loss 與 GHM


文章來自公眾號【機器學習煉丹術】

1 focal loss的概述

焦點損失函數 Focal Loss(2017年何凱明大佬的論文)被提出用於密集物體檢測任務。

當然,在目標檢測中,可能待檢測物體有1000個類別,然而你想要識別出來的物體,只是其中的某一個類別,這樣其實就是一個樣本非常不均衡的一個分類問題。

而Focal Loss簡單的說,就是解決樣本數量極度不平衡的問題的。

說到樣本不平衡的解決方案,相比大家是知道一個混淆矩陣的f1-score的,但是這個好像不能用在訓練中當成損失。而Focal loss可以在訓練中,讓小數量的目標類別增加權重,讓分類錯誤的樣本增加權重

先來看一下簡單的二值交叉熵的損失:

  • y’是模型給出的預測類別概率,y是真實樣本。就是說,如果一個樣本的真實類別是1,預測概率是0.9,那么\(-log(0.9)\)就是這個損失。
  • 講道理,一般我不喜歡用二值交叉熵做例子,用多分類交叉熵做例子會更舒服。

【然后看focal loss的改進】:

這個增加了一個\((1-y')^\gamma\)的權重值,怎么理解呢?就是如果給出的正確類別的概率越大,那么\((1-y')^\gamma\)就會越小,說明分類正確的樣本的損失權重小,反之,分類錯誤的樣本的損權重大


【focal loss的進一步改進】:

這里增加了一個\(\alpha\),這個alpha在論文中給出的是0.25,這個就是單純的降低正樣本或者負樣本的權重,來解決樣本不均衡的問題

兩者結合起來,就是一個可以解決樣本不平衡問題的損失focal loss。


【總結】:

  1. \(\alpha\)解決了樣本的不平衡問題;
  2. \(\beta\)解決了難易樣本不平衡的問題。讓樣本更重視難樣本,忽視易樣本。
  3. 總之,Focal loss會的關注順序為:樣本少的、難分類的;樣本多的、難分類的;樣本少的,易分類的;樣本多的,易分類的。

2 GHM

  • GHM是Gradient Harmonizing Mechanism。

這個GHM是為了解決Focal loss存在的一些問題。

【Focal Loss的弊端1】
讓模型過多的關注特別難分類的樣本是會有問題的。樣本中有一些異常點、離群點(outliers)。所以模型為了擬合這些非常難擬合的離群點,就會存在過擬合的風險。

2.1 GHM的辦法

Focal Loss是從置信度p的角度入手衰減loss的。而GHM是一定范圍內置信度p的樣本數量來衰減loss的。

首先定義了一個變量g,叫做梯度模長(gradient norm)

可以看出這個梯度模長,其實就是模型給出的置信度\(p^*\)與這個樣本真實的標簽之間的差值(距離)。g越小,說明預測越准,說明樣本越容易分類。

下圖中展示了g與樣本數量的關系:

【從圖中可以看到】

  • 梯度模長接近於0的樣本多,也就是易分類樣本是非常多的
  • 然后樣本數量隨着梯度模長的增加迅速減少
  • 然后當梯度模長接近1的時候,樣本的數量又開始增加。

GHM是這樣想的,對於梯度模長小的易分類樣本,我們忽視他們;但是focal loss過於關注難分類樣本了。關鍵是難分類樣本其實也有很多!,如果模型一直學習難分類樣本,那么可能模型的精確度就會下降。所以GHM對於難分類樣本也有一個衰減。

那么,GHM對易分類樣本和難分類樣本都衰減,那么真正被關注的樣本,就是那些不難不易的樣本。而抑制的程度,可以根據樣本的數量來決定。

這里定義一個GD,梯度密度

\[GD(g)=\frac{1}{l(g)}\sum_{k=1}^N{\delta(g_k,g)} \]

  • \(GD(g)\)是計算在梯度g位置的梯度密度;
  • \(\delta(g_k,g)\)就是樣本k的梯度\(g_k\)是否在\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)這個區間內。
  • \(l(g)\)就是\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)這個區間的長度,也就是\(\epsilon\)

總之,\(GD(g)\)就是梯度模長在\([g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]\)內的樣本總數除以\(\epsilon\).

然后把每一個樣本的交叉熵損失除以他們對應的梯度密度就行了。

\[L_{GHM}=\sum^N_{i=1}{\frac{CE(p_i,p_i^*)}{GD(g_i)}} \]

  • \(CE(p_i,p_i^*)\)表示第i個樣本的交叉熵損失;
  • \(GD(g_i)\)表示第i個樣本的梯度密度;

2.2 論文中的GHM

論文中呢,是把梯度模長划分成了10個區域,因為置信度p是從0~1的,所以梯度密度的區域長度就是0.1,比如是0~0.1為一個區域。

下圖是論文中給出的對比圖:

【從圖中可以得到】

  • 綠色的表示交叉熵損失;
  • 藍色的是focal loss的損失,發現梯度模長小的損失衰減很有效;
  • 紅色是GHM的交叉熵損失,發現梯度模長在0附近和1附近存在明顯的衰減。

當然可以想到的是,GHM看起來是需要整個樣本的模型估計值,才能計算出梯度密度,才能進行更新。也就是說mini-batch看起來似乎不能用GHM。

在GHM原文中也提到了這個問題,如果光使用mini-batch的話,那么很可能出現不均衡的情況。

【我個人覺得的處理方法】

  1. 可以使用上一個epoch的梯度密度,來作為這一個epoch來使用;
  2. 或者一開始先使用mini-batch計算梯度密度,然后模型收斂速度下降之后,再使用第一種方式進行更新。

3 python實現

上面講述的關鍵在於focal loss實現的功能:

  1. 分類正確的樣本的損失權重小,分類錯誤的樣本的損權重大
  2. 樣本過多的類別的權重較小

在CenterNet中預測中心點位置的時候,也是使用了Focal Loss,但是稍有改動。

3.1 概述


這里面和上面講的比較類似,我們忽視腳標。

  • 假設\(Y=1\),那么預測的\(\hat{Y}\)越靠近1,說明預測的約正確,然后\((1-\hat{Y})^\alpha\)就會越小,從而體現分類正確的樣本的損失權重小;otherwize的情況也是這樣。
  • 但是這里的otherwize中多了一個\((1-Y)^\beta\),這個是用來平衡樣本不均衡問題的,在后面的代碼部分會提到CenterNet的熱力圖。就會明白這個了。

3.2 代碼講解

下面通過代碼來理解:

class FocalLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.neg_loss = _neg_loss

    def forward(self, output, target, mask):
        output = torch.sigmoid(output)
        loss = self.neg_loss(output, target, mask)
        return loss

這里面的output可以理解為是一個1通道的特征圖,每一個pixel的值都是模型給出的置信度,然后通過sigmoid函數轉換成0~1區間的置信度。

而target是CenterNet的熱力圖,這一點可能比較難理解。打個比方,一個10*10的全都是0的特征圖,然后這個特征圖中只有一個pixel是1,那么這個pixel的位置就是一個目標檢測物體的中心點。有幾個1就說明這個圖中有幾個要檢測的目標物體。

然后,如果一個特征圖上,全都是0,只有幾個孤零零的1,未免顯得過於稀疏了,直觀上也非常的不平滑。所以CenterNet的熱力圖還需要對這些1為中心做一個高斯

可以看作是一種平滑:

可以看到,數字1的四周是同樣的數字。這是一個以1為中心的高斯平滑。


這里我們回到上面說到的\((1-Y)^\beta\)

對於數字1來說,我們計算loss自然是用第一行來計算,但是對於1附近的其他點來說,就要考慮\((1-Y)^\beta\)了。越靠近1的點的\(Y\)越大,那么\((1-Y)^\beta\)就會越小,這樣從而降低1附近的權重值。其實這里我也講不太明白,就是根據距離1的距離降低負樣本的權重值,從而可以實現樣本過多的類別的權重較小


我們回到主題,對output進行sigmoid之后,與output一起放到了neg_loss中。我們來看什么是neg_loss:

def _neg_loss(pred, gt, mask):
    pos_inds = gt.eq(1).float() * mask
    neg_inds = gt.lt(1).float() * mask

    neg_weights = torch.pow(1 - gt, 4)

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \
               neg_weights * neg_inds

    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss

先說一下,這里面的mask是根據特定任務中加上的一個小功能,就是在該任務中,一張圖片中有一部分是不需要計算loss的,所以先用過mask把那個部分過濾掉。這里直接忽視mask就好了。

neg_weights = torch.pow(1 - gt, 4)可以得知\(\beta=4\),從下面的代碼中也不難推出,\(\alpha=2\),剩下的內容就都一樣了。

把每一個pixel的損失都加起來,除以目標物體的數量即可。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM