Focal Loss[1]是一種用來處理單階段目標檢測器訓練過程中出現的正負、難易樣本不平衡問題的方法。關於Focal Loss,[2]中已經講的很詳細了,這篇博客主要是記錄和補充一些細節。
1.兩階段怎么處理樣本數量不平衡的問題
- 兩階段級聯的檢測方法: 因為物體可能出現在圖片中的任意位置,這些位置構成的集合過於龐大,因此在第一階段使用RPN將可能性大的位置先篩選出來。這一步會過濾掉很多易於檢測的負樣本(easy negatives)。
- 有偏差地進行采樣: 對第一階段剩下的樣本,再按照例如正負樣本1:3的比例進行采樣,這種方法相當於隱式地實現了Focal Loss中的\(\alpha\)參數。
2.Sigmoid Focal Loss
論文中沒有用一般多分類任務采取的softmax loss,而是使用了多標簽分類中的sigmoid loss(即逐個判斷屬於每個類別的概率,不要求所有概率的和為1,一個檢測框可以屬於多個類別),原因是sigmoid的形式訓練過程中會更穩定。因此RetinaNet分類subnet輸出的通道數是 KA 而不是 (K+1)A(K為類別數,A為每個cell鋪的anchor數)。
3.Focal Loss 代碼分析
MMDetection[3]中實現的Focal Loss如下:
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
論文中給出的公式是:\(FL(p_t)=-\alpha_t(1-p_t)^\gamma\log(p_t)\),下面分析代碼的邏輯:
首先給出兩個公式:
\(p_t\)為預測值,表示屬於 t 類別的概率,可統一表示為:\(p_t=p*t+(1-p)*(1-t)\)
\(\alpha_t\)為權重參數,表示屬於 t 類別的權重,可統一表示為:\(\alpha_t=\alpha*t+(1-\alpha)*(1-t)\)
帶入得:\(FL(p_t)=-\alpha_t(1-(p*t+(1-p)*(1-t)))^\gamma\log(p_t)\)
舉一個例子,設
pred=[0.1, 0.3, 0.8, 0.1, 0.1]
target=[0, 0, 1, 0, 0]
α=0.25
γ=2
# sigmoid value of pred
pred_sigmoid=[0.5250, 0.5744, 0.6900, 0.5250, 0.5250]
直接根據論文的公式計算loss可得:\(FL(pred,target)=\underbrace{-0.75*0.5250^2*\log(1-0.5250)*3 - 0.75*0.5744^2*\log(1-0.5744)}_{negatives}\underbrace{-0.25*(1-0.6900)^2*log(0.6900)}_{positives}=0.1364*5\)
與上面的py_sigmoid_focal_loss
函數(計算的是平均值)計算結果相同。
4.關於\(\alpha_t\)
因為Focal Loss的本意是將loss集中在正樣本上,所以我一直以為α=0.25是負樣本的權重,但是調試代碼時發現0.25其實是乘在正樣本上了。這是一個比較矛盾的地方,因為檢測任務中負樣本比正樣本要多很多,而且大部分都是論文中提到過的easy negatives。自然的想法當然是降低這部分loss的權重,讓訓練朝着更有意義的方向進行,所以我們給正樣本的α設大一點,負樣本是1-α,因此會比較小。直到看到[2]評論區的討論,個人覺得還是比較有說服力的:
重新去查了下focal loss論文,在gamma=0時,alpha=0.75效果更好,但當gamma=2時,alpha=0.25效果更好,個人的解釋為負樣本(IOU<=0.5)雖然遠比正樣本(IOU>0.5)要多,但大部分為IOU很小(如<0.1)以至於在gamma作用后某種程度上貢獻較大損失的負樣本甚至比正樣本還要少,所以alpha=0.25要反過來重新平衡負正樣本。
大意就是負樣本大部分都是容易檢測的,用於平衡難易樣本地γ取2時,負樣本的loss會過度地衰減,因此需要α進行反向地平衡。我沒有用代碼驗證過,不過這些都是超參,研究的意義也不大,定性地分析應該足夠。
5.TODO
mmdetection的py_sigmoid_focal_loss
實現其實有一點問題,不能直接替換sigmoid_focal_loss
,不過最近已經修改過了,這部分以后有機會再細說。