Focal Loss for Dense Object Detection 是ICCV2017的Best student paper,文章思路很簡單但非常具有開拓性意義,效果也非常令人稱贊。
GHM(gradient harmonizing mechanism) 發表於 “Gradient Harmonized Single-stage Detector",AAAI2019,是基於Focal loss的改進,也是個人推薦的一篇深度學習必讀文章。
第一部分 Focal Loss
Focal Loss的引入主要是為了解決難易樣本數量不平衡(注意,有區別於正負樣本數量不平衡)的問題,實際可以使用的范圍非常廣泛,為了方便解釋,還是拿目標檢測的應用場景來說明:
單階段的目標檢測器通常會產生高達100k的候選目標,只有極少數是正樣本,正負樣本數量非常不平衡。我們在計算分類的時候常用的損失——交叉熵的公式如下:
(1)
為了解決正負樣本不平衡的問題,我們通常會在交叉熵損失的前面加上一個參數 ,即:
(2)
但這並不能解決全部問題。根據正、負、難、易,樣本一共可以分為以下四類:
盡管 平衡了正負樣本,但對難易樣本的不平衡沒有任何幫助。而實際上,目標檢測中大量的候選目標都是像下圖一樣的易分樣本。
這些樣本的損失很低,但是由於數量極不平衡,易分樣本的數量相對來講太多,最終主導了總的損失。而本文的作者認為,易分樣本(即,置信度高的樣本)對模型的提升效果非常小,模型應該主要關注與那些難分樣本(這個假設是有問題的,是GHM的主要改進對象)
這時候,Focal Loss就上場了!
一個簡單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎!
(3)
舉個例, 取2時,如果
,
,損失衰減了1000倍!
Focal Loss的最終形式結合了上面的公式(2). 這很好理解,公式(3)解決了難易樣本的不平衡,公式(2)解決了正負樣本的不平衡,將公式(2)與(3)結合使用,同時解決正負難易2個問題!
最終的Focal Loss形式如下:
實驗表明 取2,
取0.25的時候效果最佳。
這樣以來,訓練過程關注對象的排序為正難>負難>正易>負易。
這就是Focal Loss,簡單明了但特別有用。
def py_sigmoid_focal_loss(pred, target, weight=None, gamma=2.0, alpha=0.25, reduction='mean', avg_factor=None): pred_sigmoid = pred.sigmoid() target = target.type_as(pred) pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma) loss = F.binary_cross_entropy_with_logits( pred, target, reduction='none') * focal_weight loss = weight_reduce_loss(loss, weight, reduction, avg_factor) return loss
這個代碼很容易理解,
先定義一個pt:
然后計算:
focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(gamma)
也就是這個公式:
再把BCE損失*focal_weight就行了
代碼來自於mmdetection\mmdet\models\losses,這個python版的sigmoid_focal_loss實現就是讓你拿去學習的,真正使用的是cuda編程版。真是個人性化的好框架😂
第二部分 GHM
那么,Focal Loss存在什么問題呢?
首先,讓模型過多關注那些特別難分的樣本肯定是存在問題的,樣本中有離群點(outliers),可能模型已經收斂了但是這些離群點還是會被判斷錯誤,讓模型去關注這樣的樣本,怎么可能是最好的呢?
其次, 與
的取值全憑實驗得出,且
和
要聯合起來一起實驗才行(也就是說,
和
的取值會相互影響)。
GHM(gradient harmonizing mechanism) 解決了上述兩個問題。
Focal Loss是從置信度p的角度入手衰減loss,而GHM是一定范圍置信度p的樣本數量的角度衰減loss。
文章先定義了一個梯度模長g:
代碼如下:
g = torch.abs(pred.sigmoid().detach() - target)
其中 是模型預測的概率,
是 ground-truth的標簽,
的取值為0或1.
g正比於檢測的難易程度,g越大則檢測難度越大。
至於為什么叫梯度模長,因為g是從交叉熵損失求梯度得來的:
假定 是樣本的輸出
,我們知道
,
那么 ,可以求出
看下圖梯度模長與樣本數量的關系:
可以看到,梯度模長接近於0的樣本數量最多,隨着梯度模長的增長,樣本數量迅速減少,但是在梯度模長接近於1時,樣本數量也挺多。
GHM的想法是,我們確實不應該過多關注易分樣本,但是特別難分的樣本(outliers,離群點)也不該關注啊!
這些離群點的梯度模長d要比一般的樣本大很多,如果模型被迫去關注這些樣本,反而有可能降低模型的准確度!況且,這些樣本的數量也很多!
那怎么同時衰減易分樣本和特別難分的樣本呢?太簡單了,誰的數量多衰減誰唄!那怎么衰減數量多的呢?簡單啊,定義一個變量,讓這個變量能衡量出一定梯度范圍內的樣本數量——這不就是物理上密度的概念嗎?
於是,作者定義了梯度密度 ——本文最重要的公式:
表明了樣本1~N中,梯度模長分布在
范圍內的樣本個數,
代表了
區間的長度。
因此梯度密度 的物理含義是:單位梯度模長g部分的樣本個數。
接下來就簡單了,對於每個樣本,把交叉熵CE×該樣本梯度密度的倒數即可!
用於分類的GHM損失 , N是總的樣本數量。
梯度密度的詳細計算過程如下:
首先,把梯度模長范圍划分成10個區域,這里要求輸入必須經過sigmoid計算,這樣梯度模長的范圍就限制在0~1之間:
class GHMC(nn.Module): def __init__(self, bins=10, ......): self.bins = bins edges = torch.arange(bins + 1).float() / bins ...... >>> edges = tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000,0.9000, 1.0000])
edges是每個區域的邊界,有了邊界就很容易計算出梯度模長落入哪個區間內。
然后根據網絡輸出pred和ground true計算loss:
注意,不管是Focal Loss還是GHM其實都是對不同樣本賦予不同的權重,所以該代碼前面計算的都是樣本權重,最后計算GHM Loss就是調用了Pytorch自帶的binary_cross_entropy_with_logits,將樣本權重填進去。
# 計算梯度模長 g = torch.abs(pred.sigmoid().detach() - target) # 目標檢測中很多框被設置為忽略,因此需要額外考慮。 # label_weight=1表示不忽略 label_weight=0表示忽略 valid = label_weight > 0 # 計算所有有效樣本總數 tot = max(valid.float().sum().item(), 1.0) # n 用來統計有效的區間數。 # 假如某個區間沒有落入任何梯度模長,密度為0,需要額外考慮,不然取個倒數就無窮了。 n = 0 # n valid bins # 通過循環計算落入10個bins的梯度模長數量 for i in range(self.bins): inds = (g >= edges[i]) & (g < edges[i + 1]) & valid num_in_bin = inds.sum().item() if num_in_bin > 0: # 重點,所謂的梯度密度就是1/num_in_bin weights[inds] = tot / num_in_bin n += 1 if n > 0: weights = weights / n # 把上面計算的weights填到binary_cross_entropy_with_logits里就行了 loss = torch.nn.functional.binary_cross_entropy_with_logits( pred, target, weights, reduction='sum') / tot
看看抑制的效果吧,也就是文章開頭的這張圖片:
同樣,對於回歸損失:
,其中
為修正的smooth L1 loss.
End~
因為本文着重論文的理解,很多細節沒有寫出,大家還是要去看一下原文的。