轉自:https://blog.csdn.net/Umi_you/article/details/80982190
Focal loss 出自何愷明團隊Focal Loss for Dense Object Detection一文,用於解決分類問題中數據類別不平衡以及判別難易程度差別的問題。文章中因用於目標檢測區分前景和背景的二分類問題,公式以二分類問題為例。項目需要,解決Focal loss在多分類上的實現,用此博客以記錄過程中的疑惑、細節和個人理解,Keras實現代碼鏈接放在最后。
框架:Keras(tensorflow后端)
環境:ubuntu16.04 python3.5
二分類和多分類
從初學開始就一直難以分清二分類和多分類在loss上的區別,雖然明白二分類其實是多分類的一個特殊情況,但在看Focal loss文章中的公式的時候還是不免頭暈,之前不願處理的細節如今不得不仔細從很基礎的地方開始解讀。
多分類Cross Entropy:
二分類Cross Entropy:
可以看出二分類問題的交叉熵其實是多分類擴展后的變形,在FocalLoss文章中,作者用一個分段函數代表二分類問題的CE(CrossEntropy)以及用pt的一個分段函數來代表二分類中標簽值為1的
部分(此處的標簽值為one-hot[0 1]或[1 0]中1所在的類別):
文章圖中的p(predict或probility?)等價於多分類Cross Entropy公式的y,也即經激活函數(多分類為softmax函數,二分類為sigmoid函數)后得到的概率,而文章中的y對應的是Cross Entropy中的 ,即label。
CE經分段函數pt作為自變量后可以轉化為 ,實際上 所代表的就是多分類CE中的 (標簽值)為1對應的 的值,只不過在二分類中 和 互斥(兩者之和為1),所以可以用一個分段的變量 來表示在i取不同值情況下的 ,我理解 為當前樣本的置信度, 越大置信度越大,交叉熵越小。總結:多分類中每個樣本的pt為one-hot中label為1的index對應預測結果pred的值,用代碼表達就是
了解
所代表的是什么之后,接下來多分類的Focal Loss就好解決了。接下來舉個三分類的例子來模擬一下流程大致就知道代碼怎么寫了:
假設
為softmax之后得出的結果:
為one-hot標簽:
:
:
:(注意pt可能為0,log(x)的取值不能為0,所以加上epsilon)
Fl:
可以看到3.4538..的地方本該是0才對,原因是log函數后會得到一個很小的值,而不是0,所以應該先做log再乘y_label:
原:
改:
順帶一提,在多分類中alpha參數是沒有效果的,每個樣本都乘以了同樣的權重
詳細信息可以看代碼中的注釋
代碼:Keras版本