5 分鍾理解 Focal Loss 與 GHM——解決樣本不平衡利器


Focal Loss for Dense Object Detection 是ICCV2017的Best student paper,文章思路很簡單但非常具有開拓性意義,效果也非常令人稱贊。

GHM(gradient harmonizing mechanism) 發表於 “Gradient Harmonized Single-stage Detector",AAAI2019,是基於Focal loss的改進,也是個人推薦的一篇深度學習必讀文章。

第一部分 Focal Loss

Focal Loss的引入主要是為了解決難易樣本數量不平衡(注意,有區別於正負樣本數量不平衡)的問題,實際可以使用的范圍非常廣泛,為了方便解釋,還是拿目標檢測的應用場景來說明:

單階段的目標檢測器通常會產生高達100k的候選目標,只有極少數是正樣本,正負樣本數量非常不平衡。我們在計算分類的時候常用的損失——交叉熵的公式如下:

[公式] (1)

為了解決正負樣本不平衡的問題,我們通常會在交叉熵損失的前面加上一個參數 [公式] ,即:

[公式] (2)

但這並不能解決全部問題。根據正、負、難、易,樣本一共可以分為以下四類:

file

盡管 [公式]平衡了正負樣本,但對難易樣本的不平衡沒有任何幫助。而實際上,目標檢測中大量的候選目標都是像下圖一樣的易分樣本。

這些樣本的損失很低,但是由於數量極不平衡,易分樣本的數量相對來講太多,最終主導了總的損失。而本文的作者認為,易分樣本(即,置信度高的樣本)對模型的提升效果非常小,模型應該主要關注與那些難分樣本這個假設是有問題的,是GHM的主要改進對象

這時候,Focal Loss就上場了!

一個簡單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎!

[公式] (3)

舉個例, [公式] 取2時,如果 [公式] , [公式] ,損失衰減了1000倍!

Focal Loss的最終形式結合了上面的公式(2). 這很好理解,公式(3)解決了難易樣本的不平衡,公式(2)解決了正負樣本的不平衡,將公式(2)與(3)結合使用,同時解決正負難易2個問題!

最終的Focal Loss形式如下:

[公式]

實驗表明[公式] 取2, [公式] 取0.25的時候效果最佳。

file
這樣以來,訓練過程關注對象的排序為正難>負難>正易>負易。

這就是Focal Loss,簡單明了但特別有用。

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

  

這個代碼很容易理解,

先定義一個pt:

[公式]

然后計算:

focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(gamma)

  

也就是這個公式:

[公式]

再把BCE損失*focal_weight就行了

[公式]

代碼來自於mmdetection\mmdet\models\losses,這個python版的sigmoid_focal_loss實現就是讓你拿去學習的,真正使用的是cuda編程版。真是個人性化的好框架😂

 

第二部分 GHM

 

那么,Focal Loss存在什么問題呢?

首先,讓模型過多關注那些特別難分的樣本肯定是存在問題的,樣本中有離群點(outliers),可能模型已經收斂了但是這些離群點還是會被判斷錯誤,讓模型去關注這樣的樣本,怎么可能是最好的呢?

 

其次 [公式] 與 [公式] 的取值全憑實驗得出,且 [公式] 和 [公式] 要聯合起來一起實驗才行(也就是說, [公式] 和 [公式] 的取值會相互影響)。

GHM(gradient harmonizing mechanism) 解決了上述兩個問題。

Focal Loss是從置信度p的角度入手衰減loss,而GHM是一定范圍置信度p的樣本數量的角度衰減loss。

文章先定義了一個梯度模長g

[公式]

代碼如下:

 

g = torch.abs(pred.sigmoid().detach() - target)

  

其中 [公式] 是模型預測的概率,[公式]是 ground-truth的標簽, [公式] 的取值為0或1.

g正比於檢測的難易程度,g越大則檢測難度越大。

至於為什么叫梯度模長,因為g是從交叉熵損失求梯度得來的:

[公式]

假定 [公式] 是樣本的輸出 [公式] ,我們知道 [公式] ,

那么 [公式] ,可以求出

[公式][公式]

[公式]

看下圖梯度模長與樣本數量的關系:

file

可以看到,梯度模長接近於0的樣本數量最多,隨着梯度模長的增長,樣本數量迅速減少,但是在梯度模長接近於1時,樣本數量也挺多。

GHM的想法是,我們確實不應該過多關注易分樣本,但是特別難分的樣本(outliers,離群點)也不該關注啊!

這些離群點的梯度模長d要比一般的樣本大很多,如果模型被迫去關注這些樣本,反而有可能降低模型的准確度!況且,這些樣本的數量也很多!

那怎么同時衰減易分樣本和特別難分的樣本呢?太簡單了,誰的數量多衰減誰唄!那怎么衰減數量多的呢?簡單啊,定義一個變量,讓這個變量能衡量出一定梯度范圍內的樣本數量——這不就是物理上密度的概念嗎?

於是,作者定義了梯度密度 [公式]——本文最重要的公式

[公式]

[公式] 表明了樣本1~N中,梯度模長分布在 [公式] 范圍內的樣本個數, [公式] 代表了 [公式] 區間的長度。

因此梯度密度 [公式]的物理含義是:單位梯度模長g部分的樣本個數。

接下來就簡單了,對於每個樣本,把交叉熵CE×該樣本梯度密度的倒數即可!

用於分類的GHM損失 [公式][公式] , N是總的樣本數量。

梯度密度的詳細計算過程如下:

首先,把梯度模長范圍划分成10個區域,這里要求輸入必須經過sigmoid計算,這樣梯度模長的范圍就限制在0~1之間:

class GHMC(nn.Module):
    def __init__(self, bins=10, ......):
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
......

>>> edges = tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 
                  0.5000, 0.6000, 0.7000, 0.8000,0.9000, 1.0000])

  

edges是每個區域的邊界,有了邊界就很容易計算出梯度模長落入哪個區間內。

然后根據網絡輸出pred和ground true計算loss:

注意,不管是Focal Loss還是GHM其實都是對不同樣本賦予不同的權重,所以該代碼前面計算的都是樣本權重,最后計算GHM Loss就是調用了Pytorch自帶的binary_cross_entropy_with_logits,將樣本權重填進去。

   # 計算梯度模長
        g = torch.abs(pred.sigmoid().detach() - target)
        # 目標檢測中很多框被設置為忽略,因此需要額外考慮。
        # label_weight=1表示不忽略 label_weight=0表示忽略
        valid = label_weight > 0
        # 計算所有有效樣本總數
        tot = max(valid.float().sum().item(), 1.0)
        # n 用來統計有效的區間數。
        # 假如某個區間沒有落入任何梯度模長,密度為0,需要額外考慮,不然取個倒數就無窮了。
        n = 0  # n valid bins 
        # 通過循環計算落入10個bins的梯度模長數量
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                # 重點,所謂的梯度密度就是1/num_in_bin
                weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n
        # 把上面計算的weights填到binary_cross_entropy_with_logits里就行了
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot

  

看看抑制的效果吧,也就是文章開頭的這張圖片:

file

同樣,對於回歸損失:

[公式] ,其中 [公式] 為修正的smooth L1 loss.

End~

因為本文着重論文的理解,很多細節沒有寫出,大家還是要去看一下原文的。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM