MMDetection Sigmoid Focal Loss解析


Focal Loss[1]是一種用來處理單階段目標檢測器訓練過程中出現的正負、難易樣本不平衡問題的方法。關於Focal Loss,[2]中已經講的很詳細了,這篇博客主要是記錄和補充一些細節。

1.兩階段怎么處理樣本數量不平衡的問題

  • 兩階段級聯的檢測方法: 因為物體可能出現在圖片中的任意位置,這些位置構成的集合過於龐大,因此在第一階段使用RPN將可能性大的位置先篩選出來。這一步會過濾掉很多易於檢測的負樣本(easy negatives)
  • 有偏差地進行采樣: 對第一階段剩下的樣本,再按照例如正負樣本1:3的比例進行采樣,這種方法相當於隱式地實現了Focal Loss中的\(\alpha\)參數。

2.Sigmoid Focal Loss

論文中沒有用一般多分類任務采取的softmax loss,而是使用了多標簽分類中的sigmoid loss(即逐個判斷屬於每個類別的概率,不要求所有概率的和為1,一個檢測框可以屬於多個類別),原因是sigmoid的形式訓練過程中會更穩定。因此RetinaNet分類subnet輸出的通道數是 KA 而不是 (K+1)A(K為類別數,A為每個cell鋪的anchor數)。

3.Focal Loss 代碼分析

MMDetection[3]中實現的Focal Loss如下:

# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    
    pred_sigmoid = pred.sigmoid()

    target = target.type_as(pred)
    
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    
    return loss

論文中給出的公式是:\(FL(p_t)=-\alpha_t(1-p_t)^\gamma\log(p_t)\),下面分析代碼的邏輯:

首先給出兩個公式:

\[p_t= \begin{cases} p, &t=1 \\ 1-p, &t=0 \end{cases} \]

\(p_t\)為預測值,表示屬於 t 類別的概率,可統一表示為:\(p_t=p*t+(1-p)*(1-t)\)

\[\alpha_t= \begin{cases} \alpha, &t=1 \\ 1-\alpha, &t=0 \end{cases} \]

\(\alpha_t\)為權重參數,表示屬於 t 類別的權重,可統一表示為:\(\alpha_t=\alpha*t+(1-\alpha)*(1-t)\)

帶入得:\(FL(p_t)=-\alpha_t(1-(p*t+(1-p)*(1-t)))^\gamma\log(p_t)\)

\[=\underbrace{\alpha_t(\overbrace{p(1-t)+t(1-p)}^{pt})^\gamma}_{focal\_weight}*\underbrace{-\log(p_t)}_{cross\ entropy} \]

舉一個例子,設

pred=[0.1, 0.3, 0.8, 0.1, 0.1]
target=[0, 0, 1, 0, 0]
α=0.25
γ=2
# sigmoid value of pred
pred_sigmoid=[0.5250, 0.5744, 0.6900, 0.5250, 0.5250]

直接根據論文的公式計算loss可得:\(FL(pred,target)=\underbrace{-0.75*0.5250^2*\log(1-0.5250)*3 - 0.75*0.5744^2*\log(1-0.5744)}_{negatives}\underbrace{-0.25*(1-0.6900)^2*log(0.6900)}_{positives}=0.1364*5\)
與上面的py_sigmoid_focal_loss函數(計算的是平均值)計算結果相同。

4.關於\(\alpha_t\)

因為Focal Loss的本意是將loss集中在正樣本上,所以我一直以為α=0.25是負樣本的權重,但是調試代碼時發現0.25其實是乘在正樣本上了。這是一個比較矛盾的地方,因為檢測任務中負樣本比正樣本要多很多,而且大部分都是論文中提到過的easy negatives。自然的想法當然是降低這部分loss的權重,讓訓練朝着更有意義的方向進行,所以我們給正樣本的α設大一點,負樣本是1-α,因此會比較小。直到看到[2]評論區的討論,個人覺得還是比較有說服力的:

重新去查了下focal loss論文,在gamma=0時,alpha=0.75效果更好,但當gamma=2時,alpha=0.25效果更好,個人的解釋為負樣本(IOU<=0.5)雖然遠比正樣本(IOU>0.5)要多,但大部分為IOU很小(如<0.1)以至於在gamma作用后某種程度上貢獻較大損失的負樣本甚至比正樣本還要少,所以alpha=0.25要反過來重新平衡負正樣本。

大意就是負樣本大部分都是容易檢測的,用於平衡難易樣本地γ取2時,負樣本的loss會過度地衰減,因此需要α進行反向地平衡。我沒有用代碼驗證過,不過這些都是超參,研究的意義也不大,定性地分析應該足夠。

5.TODO

mmdetection的py_sigmoid_focal_loss實現其實有一點問題,不能直接替換sigmoid_focal_loss,不過最近已經修改過了,這部分以后有機會再細說。

參考

  1. https://arxiv.org/pdf/1708.02002.pdf
  2. https://zhuanlan.zhihu.com/p/80594704
  3. https://github.com/open-mmlab/mmdetection
  4. https://mingming97.github.io/2019/03/29/mmdetection retinanet


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM