標簽平滑(Label Smoothing)詳解


什么是label smoothing?

標簽平滑(Label smoothing),像L1、L2和dropout一樣,是機器學習領域的一種正則化方法,通常用於分類問題,目的是防止模型在訓練時過於自信地預測標簽,改善泛化能力差的問題。

為什么需要label smoothing?

對於分類問題,我們通常認為訓練數據中標簽向量的目標類別概率應為1,非目標類別概率應為0。傳統的one-hot編碼的標簽向量\(y_i\)為,

\[y_i=\begin{cases}1,\quad i=target\\ 0,\quad i\ne target \end{cases}\\ \]

在訓練網絡時,最小化損失函數\(H(y,p)=-\sum\limits_i^K{y_ilogp_i}\),其中\(p_i\)由對模型倒數第二層輸出的logits向量z應用Softmax函數計算得到,

\[p_i=\dfrac{\exp(z_i)}{\sum_j^K\exp(z_j)} \]

傳統one-hot編碼標簽的網絡學習過程中,鼓勵模型預測為目標類別的概率趨近1,非目標類別的概率趨近0,即最終預測的logits向量(logits向量經過softmax后輸出的就是預測的所有類別的概率分布)中目標類別\(z_i\)的值會趨於無窮大,使得模型向預測正確與錯誤標簽的logit差值無限增大的方向學習,而過大的logit差值會使模型缺乏適應性,對它的預測過於自信。在訓練數據不足以覆蓋所有情況下,這就會導致網絡過擬合,泛化能力差,而且實際上有些標注數據不一定准確,這時候使用交叉熵損失函數作為目標函數也不一定是最優的了。

label smoothing的數學定義

label smoothing結合了均勻分布,用更新的標簽向量\(\hat y_{i}\)來替換傳統的ont-hot編碼的標簽向量\(y_{hot}\):

\[\hat y_{i}=y_{hot}(1-\alpha)+\alpha/K \\ \]

其中K為多分類的類別總個數,\(\alpha\)是一個較小的超參數(一般取0.1),即

\[\hat y_i=\begin{cases}1-\alpha,\quad i=target\\ \alpha/K,\quad i\ne target \end{cases}\\ \]

這樣,標簽平滑后的分布就相當於往真實分布中加入了噪聲,避免模型對於正確標簽過於自信,使得預測正負樣本的輸出值差別不那么大,從而避免過擬合,提高模型的泛化能力。

NIPS 2019上的這篇論文When Does Label Smoothing Help?用實驗說明了為什么Label smoothing可以work,指出標簽平滑可以讓分類之間的cluster更加緊湊,增加類間距離,減少類內距離,提高泛化性,同時還能提高Model Calibration(模型對於預測值的confidences和accuracies之間aligned的程度)。但是在模型蒸餾中使用Label smoothing會導致性能下降。

具體label smoothing的實現代碼可以參看OpenNMT的pytorch實現https://github.com/OpenNMT/OpenNMT-py/blob/e8622eb5c6117269bb3accd8eb6f66282b5e67d9/onmt/utils/loss.py#L186


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM