論文鏈接:Focal loss for dense object detection
總體上講,Focal Loss是一個緩解分類問題中類別不平衡、難易樣本不均衡的損失函數。首先看一下論文中的這張圖:
解釋:
- 橫軸是ground truth類別對應的概率(經過sigmoid/softmax處理過的logits),縱軸是對應的loss值;
- 藍色的線(gamma=0),就是原始交叉熵損失函數,可以明顯看出ground truth的概率越大,loss越小,符合常識;
- 除了藍色的線,其他幾個都是Focal Loss的線,其實原始交叉熵損失函數是Focal Loss的特殊版本(gamma=0)
- 其他幾個Focal Loss線都在藍色下邊,可以看出Focal Loss的作用就是【衰減】;
- 從圖中可以看出,ground truth的概率越大(即容易分類的簡單樣本),衰減越厲害,也就是大大降低了簡單樣本的loss;
- 從圖中可以看出,ground truth的概率越小(即不易分類的困難樣本),也是有衰減的,但是衰減的程度比較小;
下邊是我自己模擬的一組數據,一組固定的logits=[0+epsilon, 0.1, 0.2, ..., 0.9, 1.0-epsilon],然后假設ground truth分別是0、1、2、...、9、10的時候,gamma=0、0.5、1、2、...、8、16對應的loss。
例如第3行第1列的2.75表示,ground truth是類別2,即對應的logits是0.2,gamma=0的時候,loss=2.75(gamma=0,就是原始的多分類交叉熵)。
根據上表可以得到下邊的圖:
從上圖可以看出,隨着gamma增大,整體loss都下降了,但是logits相對越高(這個例子中最大logits=1),下降的倍數越大。從上表的最后一列也可以看出來,gamma=0和gamma=16的時候,logits=0只衰減了2倍,但是logits=1衰減了16倍。
因為論文中沒有給出比較官方的focal loss實現,所以網上focal loss有很多實現版本。有以下幾個判斷標准:
- 當gamma為0的時候,等同於原始交叉熵損失;
- 二分類版本需要同時考慮正負樣本的影響,多分類版本只需要考慮true label的影響,因為softmax的時候,已經考慮了其他labels;
- 多分類版本因為每個樣本其實只需要1個值(即y_true one-hot向量中值為1的那個),所以有些實現會用tf.gather簡化計算;
二分類Focal Loss
二分類交叉熵損失函數
其中,y是ground truth 類別,p是模型預測樣本類別為1的概率(則1-p是樣本類別為0的概率)。
為了簡化公式,用pt表示概率:
所以二分類交叉熵公式就是:
為了處理類別不均衡問題,我們可以給二分類交叉熵公式加上一個alpha參數,實際應用中,alpha通常會根據逆類別頻率或者當作超參數根據交叉驗證得到:
二分類Focal Loss
上邊引入了alpha參數可以緩解類別不均衡問題,但是無法處理難易樣本不均衡問題。為了處理難易樣本不均衡的問題,可以引入一個調節因子(1-pt)gamma,例如gamma=2,則調節因子就是(1-pt)2。這個調節因子是個小於1的,所以可以起到衰減的作用,而且pt越接近1(模型置信度越高,說明樣本越簡單),衰減的越厲害。
當然,我們也可以給這個損失函數再加上alpha,在原論文的實驗中,這個會有一些提升。
二分類Focal Loss的Tensorflow實現
需要注意的地方:
- 要知道公式中的pt是類別對應的probs,而不是logits(logits經過sigmoid/softmax變成probs);
- 很多代碼中都用y_pred變量,自己要搞清楚y_pred是指logits還是probs;
- 二分類的p_t是要同時計算正/負樣本的,這里和多分類有區別;
下邊的代碼參考了這里【p.s. 這篇文章的多分類Focal Loss可能有問題?gamma=0時不等同原始交叉熵損失。】,但是也做了些調整。
def binary_focal_loss(gamma=2, alpha=0.25):
alpha = tf.constant(alpha, dtype=tf.float32)
gamma = tf.constant(gamma, dtype=tf.float32)
def binary_focal_loss_fixed(n_classes, logits, true_label):
epsilon = 1.e-8
# 得到y_true和y_pred
y_true = tf.one_hot(true_label, n_classes)
probs = tf.nn.sigmoid(logits)
y_pred = tf.clip_by_value(probs, epsilon, 1. - epsilon)
# 得到調節因子weight和alpha
## 先得到y_true和1-y_true的概率【這里是正負樣本的概率都要計算哦!】
p_t = y_true * y_pred \
+ (tf.ones_like(y_true) - y_true) * (tf.ones_like(y_true) - y_pred)
## 然后通過p_t和gamma得到weight
weight = tf.pow((tf.ones_like(y_true) - p_t), gamma)
## 再得到alpha,y_true的是alpha,那么1-y_true的是1-alpha
alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
# 最后就是論文中的公式,相當於:- alpha * (1-p_t)^gamma * log(p_t)
focal_loss = - alpha_t * weight * tf.log(p_t)
return tf.reduce_mean(focal_loss)
多分類Focal Loss
多分類交叉熵損失函數
首先看一下多分類的交叉熵損失函數:
其中y_i為第i個類別對應的真實標簽(一個one-hot向量,只有第i個位置為1),f_i(x)為對應的模型輸出值,也就是p_t,也就是經過softmax處理過的logits。直觀的解釋就是:對於每個樣本,從p_t數組中選擇第i個數取對數,再乘-1,就是這個樣本的loss了,所以y_i one-hot向量就是起一個選擇的作用,為1,即選擇,為0,即不選。
多分類Focal Loss
從公式上看,多分類Focal Loss和二分類Focal Loss沒啥區別,也是加上一個調節因子weight=(1-pt)^gamma和alpha。
多分類Focal Loss的Tensorflow實現
首先看一下多分類交叉熵損失函數的實現
def test_softmax_cross_entropy_with_logits(n_classes, logits, true_label):
epsilon = 1.e-8
# 得到y_true和y_pred
y_true = tf.one_hot(true_label, n_classes)
softmax_prob = tf.nn.softmax(logits)
y_pred = tf.clip_by_value(softmax_prob, epsilon, 1. - epsilon)
# 得到交叉熵,其中的“-”符號可以放在好幾個地方,都是等效的,最后取mean是為了兼容batch訓練的情況。
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_true*tf.log(y_pred)))
return cross_entropy
所以需要做的就是往上邊這段代碼中加入gamma和alpha參數:
def test_softmax_focal_ce_3(n_classes, gamma, alpha, logits, label):
epsilon = 1.e-8
# y_true and y_pred
y_true = tf.one_hot(label, n_classes)
probs = tf.nn.softmax(logits)
y_pred = tf.clip_by_value(probs, epsilon, 1. - epsilon)
# weight term and alpha term【因為y_true是只有1個元素為1其他元素為0的one-hot向量,所以對於每個樣本,只有y_true位置為1的對應類別才有weight,其他都是0】這也是為什么網上有的版本會用到tf.gather函數,這個函數的作用就是只把有用的這個數取出來,可以省略一些0相關的運算。
weight = tf.multiply(y_true, tf.pow(tf.subtract(1., y_pred), gamma))
if alpha != 0.0: # 我這實現中的alpha只是起到了調節loss倍數的作用(調節倍數對訓練沒影響,因為loss的梯度才是影響訓練的關鍵),要想起到調節類別不均衡的作用,要替換成數組,數組長度和類別總數相同,每個元素表示對應類別的權重。另外[這篇](https://blog.csdn.net/Umi_you/article/details/80982190)博客也提到了,alpha在多分類Focal loss中沒作用,也就是只能調節整體loss倍數,不過如果換成數組形式的話,其實是可以達到緩解類別不均衡問題的目的。
alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
else:
alpha_t = tf.ones_like(y_true)
# origin x ent,這里計算原始的交叉熵損失
xent = tf.multiply(y_true, -tf.log(y_pred))
# focal x ent,對交叉熵損失進行調節,“-”號放在上一行代碼了,所以這里不需要再寫“-”了。
focal_xent = tf.multiply(alpha_t, tf.multiply(weight, xent))
# in this situation, reduce_max is equal to reduce_sum,因為經過y_true選擇后,每個樣本只保留了true label對應的交叉熵損失,所以使用max和使用sum是同等作用的。
reduced_fl = tf.reduce_max(focal_xent, axis=1)
return tf.reduce_mean(reduced_fl)
參考:
- Pytorch中的Focal Loss實現
- Pytorch官方實現的softmax_focal_loss
- Pytorch官方實現的sigmoid_focal_loss
- 何愷明大神的「Focal Loss」,如何更好地理解?,蘇劍林,2017-12
- https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_binary_focal_loss.py
- https://github.com/artemmavrin/focal-loss/blob/master/src/focal_loss/_categorical_focal_loss.py
- https://github.com/zhezh/focalloss/blob/master/focalloss.py
- focal loss的tensorflow實現,chris_xy,2019-03
- Multi-class classification with focal loss for imbalanced datasets,Chengwei Zhang,2018-12
- focal loss的幾種實現版本(Keras/Tensorflow),隨煜而安,2019-05【這篇文章的多分類Focal Loss有問題,gamma=0時不等同原始交叉熵損失。】
- keras中兩種交叉熵損失函數的探討,TAURUS,2019-08
- focal loss for multi-class classification,yehaihai,2018-07【這篇文章說alpha對於多分類Focal Loss不起作用,其實取決於alpha的含義,如果只是1個標量,的確無法起到緩解類別不均衡問題的作用,但如果alpah是一個數組(每個元素表示類別的權重),其實是alpha是可以在多分類Focal Loss中起到緩解類別不均衡作用的。】