利用深度學習做多分類在工業或是在科研環境中都是常見的任務。在科研環境下,無論是NLP、CV或是TTS系列任務,數據都是豐富且干凈的。而在現實的工業環境中,數據問題常常成為困擾從業者的一大難題;常見的數據問題包含有:
- 數據樣本量少
- 數據缺乏標注
- 數據不干凈,存在大量的擾動
- 數據的類間樣本數量分布不均衡等等。
除此之外,還存在其他的問題,本文不逐一列舉。針對上述第4個問題,2020年7月google發表論文《 Long-Tail Learning via Logit Adjustment 》 通過 BER ( Balanced Error Rate ) 對交叉熵函數的相關推理,在原有的交叉熵的基礎上進行改造,使得平均分類精度更高。本文將簡要解讀該論文的核心推論,並使用 keras 深度學習框架進行實現,最后通過簡單的Mnist手寫數字分類的實驗驗證結果。本文將從以下四個方面進行解讀:
- 基本概念
- 核心推論
- 代碼實現
- 實驗結果
1. 基本概念
基於深度學習的多分類問題中,想要獲得更優的分類效果往往需要對數據、神經網絡的結構參數、損失函數以及訓練參數做出調整;尤其是在面對類別不均衡的數據時,做出的調整更多。在論文《 Long-Tail Learning via Logit Adjustment 》中,為了緩解類別不均衡造成的低樣本類別分類准確率低的問題,只向損失函數中加入了標簽的先驗知識便獲得了SOTA效果。因此,本文針對其核心推論,首先簡要闡述四個基本概念:(1)長尾分布、(2)softmax、(3)交叉熵、(4)BER
1.1 長尾分布
如果將所有類別的訓練數據按照每類的樣本量進行從高到低的排序,並將排序結果表現在圖形上,那么類別不均衡的訓練數據將會呈現出 “頭部” 和 “尾部” 的分布形式,如下圖所示:
樣本量多的類別形成 “頭部” ,樣本量低的類別形成 “尾部” ,類別不均衡問題很顯著。
1.2 softmax
softmax 由於其歸一化的功能以及易於求導的特點,常在二分類或多分類問題中作為神經網絡最后一層的激活函數,用於表達神經網絡的預測輸出。本文對softmax 不多做贅述,只給出其一般化的公式:
在神經網絡中,\(z_{j}\)是上一層的輸出;\(q\left(c_{j}\right)\)是本層輸出的分布形式;\(\sum_{i=1}^{n} e^{z_{i}}\)是一個 batch內\(e^{z_{i}}\)的和。
1.3 交叉熵
本文對交叉熵函數不做過多推論,詳情可查閱信息論的相關文獻。二分類或多分類問題中,通常以交叉熵函數及其變體作為損失函數進行優化,給出基本公式:
在神經網絡中,\(p\left(c_{i}\right)\)是期望的樣本分布,通常是one-hot編碼后的標簽;\(q\left(c_{i}\right)\)是神經網絡的輸出,可視作神經網絡對樣本的預測結果。
1.4 BER
BER 在二分類中為正例樣本和負例樣本中各自預測錯誤率的均值;在多分類問題中為各類樣本各自錯誤率的加權和,可以表示為以下形式(參照論文):
其中,\(f\)是整個神經網絡;\(f_{y^{\prime}}(x)\)表示輸入為\(x\),輸出為\(y^{\prime}\)的神經網絡;\(y \notin \operatorname{argmax}_{y^{\prime} \in y} f_{y^{\prime}}(x)\)表示被神經網絡錯誤識別的標簽\(y\);\(\mathbb{P}_{x \mid y}\)即為錯誤率的計算形式;\(\frac{1}{L}\)為各類權重。
2. 核心推論
按照論文思路,首先確定一個神經網絡模型:
即\(f^{*}\)為滿足BER條件的一個神經網絡模型。接着優化這一神經網絡模型\(\operatorname{argmax}_{y \in[L]} f_{y}^{*}(x)\),這一過程等價於\(\operatorname{argmax}_{y \in[L]} \mathbb{P}^{\mathrm{bal}}(y \mid x)\),即給定訓練數據\(x\)得到預測標簽\(y\),並將預測標簽\(y\)均衡化(乘上各自權重)的優化過程。簡寫為:
對於\(\mathbb{P}^{\text {bal }}(y \mid x)\),顯然\(\mathbb{P}^{\text {bal }}(y \mid x) \propto \mathbb{P}(y \mid x) / \mathbb{P}(y)\),其中\(\mathbb{P}(y)\)是標簽先驗;\(\mathbb{P}(y \mid x)\)是給定訓練數據\(x\)之后的預測標簽的條件概率。結合多分類神經網絡中訓練的實質:
依照上述過程,假設將網絡輸出logits記為s*:\(s^{*}: x \rightarrow \mathbb{R}^{L}\),由於\(s^{*}\)需要通過 softmax 激活層,即\(q\left(c_{i}\right)=\frac{e^{s^{*}}}{\sum_{i=1}^{n} e^{s^{*}}}\);因此不難得出:\(\mathbb{P}(y \mid x) \propto \exp \left(s_{y}^{*}(x)\right)\)。再結合\(\mathbb{P}^{\text {bal }}(y \mid x) \propto \mathbb{P}(y \mid x) / \mathbb{P}(y)\),可以將\(\mathbb{P}^{\text {bal }}(y \mid x)\)表示為:
參照上式,論文中給出了優化\(\mathbb{P}^{\text {bal }}(y \mid x)\)的兩種實現方式:
(1) 通過 \(\operatorname{argmax}_{y \in[L]} \exp \left(s_{y}^{*}(x)\right) / \mathbb{P}(y)\) ,在輸入 \(x\)通過所有神經網絡層得到預測predict后,除以一個先驗\(\mathbb{P}(y)\)。這種方法前人已經用過了,並且取得了一定的效果。
(2)通過 \(\operatorname{argmax}_{y \in[L]} s_{y}^{*}(x)-\ln \mathbb{P}(y)\) ,在輸入 \(x\)通過神經網絡層得到一個編碼logits后減去一個\(\ln \mathbb{P}(y)\)。論文采用的是這一種思路。
依照第二條思路,論文直接給出了一個一般化的式子,稱之為logit adjustment loss:
對比常規的softmax交叉熵:
本質上是將一個與標簽先驗有關的偏移量應用到了每一個對數輸出中(即經過softmax激活之前的結果)。
3. 代碼實現
實現的思想在於:對神經網絡的輸出logits加上一個基於先驗的偏移\(\log \left(\frac{\pi_{y^{\prime}}}{\pi_{y}}\right)^{\tau}\)。在實際中,為了在盡量有效的前提下簡便實現,取調節因子 \(\tau\)=1,\(\pi_{y^{\prime}}\)=1。則logit adjustment loss簡化為:
在keras框架下實現如下:
import keras.backend as K
def CE_with_prior(one_hot_label, logits, prior, tau=1.0):
'''
param: one_hot_label
param: logits
param: prior: real data distribution obtained by statistics
param: tau: regulator, default is 1
return: loss
'''
log_prior = K.constant(np.log(prior + 1e-8))
# align dim
for _ in range(K.ndim(logits) - 1):
log_prior = K.expand_dims(log_prior, 0)
logits = logits + tau * log_prior
loss = K.categorical_crossentropy(one_hot_label, logits, from_logits=True)
return loss
4. 實驗結果
論文《 Long-Tail Learning via Logit Adjustment 》本身對比了多種提升長尾分布分類精度方法,並使用了不同的數據集進行測試,測試表現優於現有的方法,詳細的實驗結果參照論文本身。本文為了快速驗證實現的正確性,以及該方法的有效性,使用mnist手寫數字進行了簡單的分類實驗。實驗背景如下:
\ | 詳情 |
---|---|
訓練樣本 | 0 ~ 4:5000張/類;5 ~ 9 :500張/類 |
測試樣本 | 0 ~ 9:500/類 |
運行環境 | 本地CPU |
網絡結構 | 卷積+最大池化+全連接 |
在上述背景下進行對比實驗,對比標准的多分類交叉熵和帶先驗的交叉熵分別作為loss函數下,分類網絡的表現。取相同的epoch=60,實驗結果如下:
\ | 標准多分類交叉熵 | 帶先驗的交叉熵 |
---|---|---|
准確率 | 0.9578 | 0.9720 |
訓練流程圖 | ![]() |
![]() |
PS:
我們是行者AI,我們在“AI+游戲”中不斷前行。
如果你也對游戲感興趣,對AI充滿好奇,那就快來加入我們(hr@xingzhe.ai)。