0 前言
Focal Loss是為了處理樣本不平衡問題而提出的,經時間驗證,在多種任務上,效果還是不錯的。在理解Focal Loss前,需要先深刻理一下交叉熵損失,和帶權重的交叉熵損失。然后我們從樣本權重的角度出發,理解Focal Loss是如何分配樣本權重的。Focal是動詞Focus的形容詞形式,那么它究竟Focus在什么地方呢?(詳細的代碼請看Gitee)。
1 交叉熵
1.1 交叉熵損失(Cross Entropy Loss)
有\(N\)個樣本,輸入一個\(C\)分類器,得到的輸出為\(X\in \mathcal{R}^{N\times C}\),它共有\(C\)類;其中某個樣本的輸出記為\(x\in \mathcal{R}^{1\times C}\),即\(x[j]\)是\(X\)的某個行向量,那么某個交叉熵損失可以寫為如下公式:
其中\(\text{class}\in [0,\ C)\)是這個樣本的類標簽,如果給出了類標簽的權重向量\(W\in \mathcal{R}^{1\times C}\),那么帶權重的交叉熵損失可以更改為如下公式:
最終對這個\(N\)個樣本的損失求和或者求平均:
這個就是我們平時經常用到的交叉熵損失了。
1.2 二分類交叉熵損失(Binary Cross Entropy Loss)
上面所提到的交叉熵損失是適用於多分類(二分類及以上)的,但是它的公式看起來似乎與我們平時在書上或論文中看到的不一樣,一般我們常見的交叉熵損失公式如下:
這是一個典型的二分類交叉熵損失,其中\(y\in\{0,\ 1\}\)表示標簽值,\(\hat{y}\in[0,\ 1]\)表示分類模型的類別1預測值。上面這個公式是一個綜合的公式,它等價於:
其中\(\hat{y}_0, \hat{y}_1\)是二分類模型輸出的2個偽概率值。
例:如果二分類模型是神經網絡,且最后一層為: 2個神經元+Softmax,那么\(\hat{y}_0, \hat{y}_1\)就對應着這兩個神經元的輸出值。當然它也可以帶上類別的權重。
同樣地,有\(N\)個樣本,輸入一個2分類器,得到的輸出為\(X\in \mathcal{R}^{N\times 2}\),再經過Softmax函數,\(\hat{Y}=\sigma(X)\in \mathcal{R}^{N\times 2}\),標簽為\(Y\in \mathcal{R}^{N\times 2}\),每個樣本的二分類損失記為\(l^{(i)}, i=0,1,2,\cdots,N\),最終對這個\(N\)個樣本的損失求和或者求平均:
注:如果一次只訓練一個樣本,即\(N=1\),那么上面帶類別權重的損失中的權重是無效的。因為
權重是相對的,某一個樣本的權重大,那么必然需要有另一個樣本的權重小,這樣才能體現出這一批樣本中某些樣本的重要性。\(N=1\)時,已沒有權重的概念,它是唯一的,也是最重要的。\(N=1\),或者說batch_size=1這種情況在訓練視頻\文章數據時,是會常出現的。由於我們顯示/內存的限制,而視頻/文章數據又比較大,一次只能訓練一個樣本,此時我們就需要注意權重的問題了。
2 Focal Loss
2.1 基本思想
一般來講,Focal Loss(以下簡稱FL)[1]是為解決樣本不平衡的問題,但是更准確地講,它是為解決難分類樣本(Hard Example)和易分類樣本(Easy Example)的不平衡問題。對於樣本不平衡,其實通過上面的帶權重的交叉熵損失便可以一定程度上解決這個問題,但是在實際問題中,以權重來解決樣本不平衡問題的效果不夠理想,此時我們應當思考,表面上我們的樣本不平衡,但實質上導致效果不好的原因也許並不是簡單地因為樣本不平衡,而是因為樣本中存在一些Hard Example,同時存在許多Easy Example,Easy Example雖然容易被分類器分辨,損失較小,但是由於其數量大,它們累積起來依然於大於Hard Example的Loss值,因此我們需要給Hard Example較大的權重,而Easy Example較小的權重。
那么什么叫Hard Example,什么叫Easy Example呢?看下面的圖就知道了。
|
|
|
|
| 圖2-1 Hard Example | 圖2-2 Easy Example1 | 圖2-3 Easy Example2 | 圖2-4 Example Space |
假設,我們的任務是訓練一個分類器,分類出人和馬,對於上面的三張圖,圖2-2和圖2-3應該是非常容易判斷出來的,但是圖2-1就是不那么容易了,它即有人的特征,又有馬的特征,非常容易混淆。這種樣本雖然在數據集中出現的頻率可能並不高,但是想要提高分類器的性能,需要着力解決這種樣本分類問題。
提出Hard Example和Easy Example后,可以將樣本空間划分為如圖2-4所示的樣本空間。其中縱軸為多數類樣本(Majority Class)和少數類樣本(Minority Class),上面的帶權重的交叉熵損失只能解決Majority Class和Minority Class的樣本不平衡問題,並沒有考慮Hard Example和Easy Example的問題,Focal Loss的提出就是為解決這個難易樣本的分類問題。
2.2 Focal Loss解決方案
要解決難易樣本的分類問題,首先就需要找出Hard Example和Easy Example。這對於神經網絡來說,應該是一件比較容易的事情。如圖2-6所示,這是一個5分類的網絡,神經網絡的最后一層輸出時,加上一個Softmax或者Sigmoid就會得到輸出的偽概率值,代表着模型預測的每個類別的概率,
|
|
| 圖2-6 Easy Example Classifier Output | 圖2-7 Hard Example Classifier Output |
圖2-6中,樣本標簽為1,分類器輸出值最大的為第1個神經元(以0開始計數),這剛好預測准確,而且其輸出值2也比其它神經元的輸出值要大不少,因此可以認為這是一個易分類樣本(Easy Example);圖2-7的樣本標簽是3,分類器輸出值最大的為第4個神經元,並且這幾個神經元的輸出值都相差不大,神經網絡無法准確判斷這個樣本的類別,所以可以認為這是一個難分類樣本(Hard Example)。其實說白了,判斷Easy/Hard Example的方法就是看分類網絡的最后的輸出值。如果網絡預測准確,且其概率較大,那么這是一個Easy Example,如果網絡輸出的概率較小,這是一個Hard Example。下面用數學公式嚴謹地表達來Focal Loss的表達式。
令一個\(C\)類分類器的輸出為\(\boldsymbol{y}\in \mathcal{R}^{C\times 1}\),定義函數\(f\)將輸出\(\boldsymbol{y}\)轉為偽概率值\(\boldsymbol{p}=f(\boldsymbol{y})\),當前樣本的類標簽為\(t\),記\(p_t=\boldsymbol{p}[t]\),它表示分類器預測為\(t\)類的概率值,再結合上面的交叉熵損失,定義Focal Loss為:
這實質就是交叉熵損失前加了一個權重,只不過這個權重有點不一樣的來頭。為了更好地控制前面權重的大小,可以給前面的權重系數添加一個指數\(\gamma\),那么更改式(2-1):
其中\(\gamma\)一值取值為2就好,\(\gamma\)取值為0時與交叉熵損失等價,\(\gamma\)越大,就越抑制Easy Example的損失,相對就會越放大Hard Example的損失。同時為解決樣本類別不平衡的問題,可以再給式(2-2)添加一個類別的權重\(\alpha_t\)(這個類別權重上面的交叉熵損失已經實現):
到這里,Focal Loss理論就結束了,非常簡單,但是有效。
3 Focal Loss實現(Pytorch)
3.1 交叉熵損失實現(numpy)
為了更好的理解Focal Loss的實現,先理解交叉熵損失的實現,我這里用numpy簡單地實現了一下交叉熵損失。
import numpy as np
def cross_entropy(output, target):
out_exp = np.exp(output)
out_cls = np.array([out_exp[i, t] for i, t in enumerate(target)])
ce = -np.log(out_cls / out_exp.sum(1))
return ce
代碼中第5行,可能稍微有點難以理解,它不過是為了找出標簽對應的輸出值。比如第2個樣本的標簽值為3,那它分類器的輸出應當選擇第2行,第3列的值。
3.2 Focal Loss實現
下面的代碼的1012行:依據輸出,計算概率,再將其轉為`focal_weight`;1516行,將類權重和focal_weight添加到交叉熵損失,得到最終的focal_loss;18~21行,實現mean和sum兩種reduction方法,注意求平均不是簡單的直接平均,而是加權平均。
class FocalLoss(nn.Module):
def __init__(self, gamma=2, weight=None, reduction='mean'):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.reduction = reduction
def forward(self, output, target):
# convert output to pseudo probability
out_target = torch.stack([output[i, t] for i, t in enumerate(target)])
probs = torch.sigmoid(out_target)
focal_weight = torch.pow(1-probs, self.gamma)
# add focal weight to cross entropy
ce_loss = F.cross_entropy(output, target, weight=self.weight, reduction='none')
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
focal_loss = (focal_loss/focal_weight.sum()).sum()
elif self.reduction == 'sum':
focal_loss = focal_loss.sum()
return focal_loss
注:上面實現中,output的維度應當滿足
output.dim==2,並且其形狀為(batch_size, C),且target.max()<C。
總結
Focal Loss從2017年提出至今,該論文已有2000多引用,足以說明其有效性。其實從本質上講,它也只不過是給樣本重新分配權重,它相對類別權重的分配方法,只不過是將樣本空間進行更為細致的划分,從圖2-4很容易理解,類別權重的方法,只是將樣本空間划分為藍色線上下兩個部分,而加入難易樣本的划分,又可以將空間划分為左右兩個部分,如此,樣本空間便被划分4個部分,這樣更加細致。其實借助於這個思想,我們是否可以根據不同任務的需求,更加細致划分我們的樣本空間,然后再相應的分配不同的權重呢?
