圖解Focal Loss以及Tensorflow實現(二分類、多分類)


論文鏈接:Focal loss for dense object detection

總體上講,Focal Loss是一個緩解分類問題中類別不平衡、難易樣本不均衡的損失函數。首先看一下論文中的這張圖:

解釋:

  • 橫軸是ground truth類別對應的概率(經過sigmoid/softmax處理過的logits),縱軸是對應的loss值;
  • 藍色的線(gamma=0),就是原始交叉熵損失函數,可以明顯看出ground truth的概率越大,loss越小,符合常識;
  • 除了藍色的線,其他幾個都是Focal Loss的線,其實原始交叉熵損失函數是Focal Loss的特殊版本(gamma=0)
  • 其他幾個Focal Loss線都在藍色下邊,可以看出Focal Loss的作用就是【衰減】;
  • 從圖中可以看出,ground truth的概率越大(即容易分類的簡單樣本),衰減越厲害,也就是大大降低了簡單樣本的loss;
  • 從圖中可以看出,ground truth的概率越小(即不易分類的困難樣本),也是有衰減的,但是衰減的程度比較小;

下邊是我自己模擬的一組數據,一組固定的logits=[0+epsilon, 0.1, 0.2, ..., 0.9, 1.0-epsilon],然后假設ground truth分別是0、1、2、...、9、10的時候,gamma=0、0.5、1、2、...、8、16對應的loss。
例如第3行第1列的2.75表示,ground truth是類別2,即對應的logits是0.2,gamma=0的時候,loss=2.75(gamma=0,就是原始的多分類交叉熵)。

根據上表可以得到下邊的圖:

從上圖可以看出,隨着gamma增大,整體loss都下降了,但是logits相對越高(這個例子中最大logits=1),下降的倍數越大。從上表的最后一列也可以看出來,gamma=0和gamma=16的時候,logits=0只衰減了2倍,但是logits=1衰減了16倍。

因為論文中沒有給出比較官方的focal loss實現,所以網上focal loss有很多實現版本。有以下幾個判斷標准:

  • 當gamma為0的時候,等同於原始交叉熵損失;
  • 二分類版本需要同時考慮正負樣本的影響,多分類版本只需要考慮true label的影響,因為softmax的時候,已經考慮了其他labels;
  • 多分類版本因為每個樣本其實只需要1個值(即y_true one-hot向量中值為1的那個),所以有些實現會用tf.gather簡化計算;

二分類Focal Loss

二分類交叉熵損失函數


其中,y是ground truth 類別,p是模型預測樣本類別為1的概率(則1-p是樣本類別為0的概率)。

為了簡化公式,用pt表示概率:

所以二分類交叉熵公式就是:

為了處理類別不均衡問題,我們可以給二分類交叉熵公式加上一個alpha參數,實際應用中,alpha通常會根據逆類別頻率或者當作超參數根據交叉驗證得到:

二分類Focal Loss

上邊引入了alpha參數可以緩解類別不均衡問題,但是無法處理難易樣本不均衡問題。為了處理難易樣本不均衡的問題,可以引入一個調節因子(1-pt)gamma,例如gamma=2,則調節因子就是(1-pt)2。這個調節因子是個小於1的,所以可以起到衰減的作用,而且pt越接近1(模型置信度越高,說明樣本越簡單),衰減的越厲害。

當然,我們也可以給這個損失函數再加上alpha,在原論文的實驗中,這個會有一些提升。

二分類Focal Loss的Tensorflow實現

需要注意的地方:

  1. 要知道公式中的pt是類別對應的probs,而不是logits(logits經過sigmoid/softmax變成probs);
  2. 很多代碼中都用y_pred變量,自己要搞清楚y_pred是指logits還是probs;
  3. 二分類的p_t是要同時計算正/負樣本的,這里和多分類有區別;

下邊的代碼參考了這里【p.s. 這篇文章的多分類Focal Loss可能有問題?gamma=0時不等同原始交叉熵損失。】,但是也做了些調整。

def binary_focal_loss(gamma=2, alpha=0.25):
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = tf.constant(gamma, dtype=tf.float32)
    def binary_focal_loss_fixed(n_classes, logits, true_label):
        epsilon = 1.e-8
        # 得到y_true和y_pred
        y_true = tf.one_hot(true_label, n_classes)
        probs = tf.nn.sigmoid(logits)
        y_pred = tf.clip_by_value(probs, epsilon, 1. - epsilon)
        # 得到調節因子weight和alpha
        ## 先得到y_true和1-y_true的概率【這里是正負樣本的概率都要計算哦!】
        p_t = y_true * y_pred \
              + (tf.ones_like(y_true) - y_true) * (tf.ones_like(y_true) - y_pred)
        ## 然后通過p_t和gamma得到weight
        weight = tf.pow((tf.ones_like(y_true) - p_t), gamma)
        ## 再得到alpha,y_true的是alpha,那么1-y_true的是1-alpha
        alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
        # 最后就是論文中的公式,相當於:- alpha * (1-p_t)^gamma * log(p_t)
        focal_loss = - alpha_t * weight * tf.log(p_t)
        return tf.reduce_mean(focal_loss)

多分類Focal Loss

多分類交叉熵損失函數

首先看一下多分類的交叉熵損失函數:

其中y_i為第i個類別對應的真實標簽(一個one-hot向量,只有第i個位置為1),f_i(x)為對應的模型輸出值,也就是p_t,也就是經過softmax處理過的logits。直觀的解釋就是:對於每個樣本,從p_t數組中選擇第i個數取對數,再乘-1,就是這個樣本的loss了,所以y_i one-hot向量就是起一個選擇的作用,為1,即選擇,為0,即不選。

多分類Focal Loss

從公式上看,多分類Focal Loss和二分類Focal Loss沒啥區別,也是加上一個調節因子weight=(1-pt)^gamma和alpha。

多分類Focal Loss的Tensorflow實現

首先看一下多分類交叉熵損失函數的實現

def test_softmax_cross_entropy_with_logits(n_classes, logits, true_label):
    epsilon = 1.e-8
    # 得到y_true和y_pred
    y_true = tf.one_hot(true_label, n_classes)
    softmax_prob = tf.nn.softmax(logits)
    y_pred = tf.clip_by_value(softmax_prob, epsilon, 1. - epsilon)
    # 得到交叉熵,其中的“-”符號可以放在好幾個地方,都是等效的,最后取mean是為了兼容batch訓練的情況。
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_true*tf.log(y_pred)))
    return cross_entropy

所以需要做的就是往上邊這段代碼中加入gamma和alpha參數:

def test_softmax_focal_ce_3(n_classes, gamma, alpha, logits, label):
    epsilon = 1.e-8
    # y_true and y_pred
    y_true = tf.one_hot(label, n_classes)
    probs = tf.nn.softmax(logits)
    y_pred = tf.clip_by_value(probs, epsilon, 1. - epsilon)

    # weight term and alpha term【因為y_true是只有1個元素為1其他元素為0的one-hot向量,所以對於每個樣本,只有y_true位置為1的對應類別才有weight,其他都是0】這也是為什么網上有的版本會用到tf.gather函數,這個函數的作用就是只把有用的這個數取出來,可以省略一些0相關的運算。
    weight = tf.multiply(y_true, tf.pow(tf.subtract(1., y_pred), gamma))
    if alpha != 0.0:  # 我這實現中的alpha只是起到了調節loss倍數的作用(調節倍數對訓練沒影響,因為loss的梯度才是影響訓練的關鍵),要想起到調節類別不均衡的作用,要替換成數組,數組長度和類別總數相同,每個元素表示對應類別的權重。另外[這篇](https://blog.csdn.net/Umi_you/article/details/80982190)博客也提到了,alpha在多分類Focal loss中沒作用,也就是只能調節整體loss倍數,不過如果換成數組形式的話,其實是可以達到緩解類別不均衡問題的目的。
        alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
    else:
        alpha_t = tf.ones_like(y_true)

    # origin x ent,這里計算原始的交叉熵損失
    xent = tf.multiply(y_true, -tf.log(y_pred))

    # focal x ent,對交叉熵損失進行調節,“-”號放在上一行代碼了,所以這里不需要再寫“-”了。
    focal_xent = tf.multiply(alpha_t, tf.multiply(weight, xent))

    # in this situation, reduce_max is equal to reduce_sum,因為經過y_true選擇后,每個樣本只保留了true label對應的交叉熵損失,所以使用max和使用sum是同等作用的。
    reduced_fl = tf.reduce_max(focal_xent, axis=1)
    return tf.reduce_mean(reduced_fl)

參考:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM