什么是label smoothing?
標簽平滑(Label smoothing),像L1、L2和dropout一樣,是機器學習領域的一種正則化方法,通常用於分類問題,目的是防止模型在訓練時過於自信地預測標簽,改善泛化能力差的問題。
為什么需要label smoothing?
對於分類問題,我們通常認為訓練數據中標簽向量的目標類別概率應為1,非目標類別概率應為0。傳統的one-hot編碼的標簽向量\(y_i\)為,
在訓練網絡時,最小化損失函數\(H(y,p)=-\sum\limits_i^K{y_ilogp_i}\),其中\(p_i\)由對模型倒數第二層輸出的logits向量z應用Softmax函數計算得到,
傳統one-hot編碼標簽的網絡學習過程中,鼓勵模型預測為目標類別的概率趨近1,非目標類別的概率趨近0,即最終預測的logits向量(logits向量經過softmax后輸出的就是預測的所有類別的概率分布)中目標類別\(z_i\)的值會趨於無窮大,使得模型向預測正確與錯誤標簽的logit差值無限增大的方向學習,而過大的logit差值會使模型缺乏適應性,對它的預測過於自信。在訓練數據不足以覆蓋所有情況下,這就會導致網絡過擬合,泛化能力差,而且實際上有些標注數據不一定准確,這時候使用交叉熵損失函數作為目標函數也不一定是最優的了。
label smoothing的數學定義
label smoothing結合了均勻分布,用更新的標簽向量\(\hat y_{i}\)來替換傳統的ont-hot編碼的標簽向量\(y_{hot}\):
其中K為多分類的類別總個數,\(\alpha\)是一個較小的超參數(一般取0.1),即
這樣,標簽平滑后的分布就相當於往真實分布中加入了噪聲,避免模型對於正確標簽過於自信,使得預測正負樣本的輸出值差別不那么大,從而避免過擬合,提高模型的泛化能力。
NIPS 2019上的這篇論文When Does Label Smoothing Help?用實驗說明了為什么Label smoothing可以work,指出標簽平滑可以讓分類之間的cluster更加緊湊,增加類間距離,減少類內距離,提高泛化性,同時還能提高Model Calibration(模型對於預測值的confidences和accuracies之間aligned的程度)。但是在模型蒸餾中使用Label smoothing會導致性能下降。
具體label smoothing的實現代碼可以參看OpenNMT的pytorch實現https://github.com/OpenNMT/OpenNMT-py/blob/e8622eb5c6117269bb3accd8eb6f66282b5e67d9/onmt/utils/loss.py#L186