什么是標簽平滑?在PyTorch中如何去使用它?
在訓練深度學習模型的過程中,過擬合和概率校准(probability calibration)是兩個常見的問題。一方面,正則化技術可以解決過擬合問題,其中較為常見的方法有將權重調小,迭代提前停止以及丟棄一些權重等。另一方面,Platt標度法和isotonic regression法能夠對模型進行校准。但是有沒有一種方法可以同時解決過擬合和模型過度自信呢?
標簽平滑也許可以。它是一種去改變目標變量的正則化技術,能使模型的預測結果不再僅為一個確定值。標簽平滑之所以被看作是一種正則化技術,是因為它可以防止輸入到softmax函數的最大logits值變得特別大,從而使得分類模型變得更加准確。
在這篇文章中,我們定義了標簽平滑化,在測試過程中我們將它應用到交叉熵損失函數中。
標簽平滑?
假設這里有一個多分類問題,在這個問題中,目標變量通常是一個one-hot向量,即當處於正確分類時結果為1,否則結果是0。
標簽平滑改變了目標向量的最小值,使它為ε。因此,當模型進行分類時,其結果不再僅是1或0,而是我們所要求的1-ε和ε,從而帶標簽平滑的交叉熵損失函數為如下公式。
在這個公式中,ce(x)表示x的標准交叉熵損失函數,例如:-log(p(x)),ε是一個非常小的正數,i表示對應的正確分類,N為所有分類的數量。
直觀上看,標記平滑限制了正確類的logit值,並使得它更接近於其他類的logit值。從而在一定程度上,它被當作為一種正則化技術和一種對抗模型過度自信的方法。
PyTorch中的使用
在PyTorch中,帶標簽平滑的交叉熵損失函數實現起來非常簡單。首先,讓我們使用一個輔助函數來計算兩個值之間的線性組合。
deflinear_combination(x, y, epsilon):return epsilon*x + (1-epsilon)*y
下一步,我們使用PyTorch中一個全新的損失函數:nn.Module.
查看全部文章請訪問 https://imba.deephub.ai/p/456cf010747e11ea90cd05de3860c663
或關注公眾號 deephub-imba

