目錄
產生背景
工作原理
參考資料
產生背景 |
假設選用softmax交叉熵訓練一個三分類模型,某樣本經過網絡最后一層的輸出為向量x=(1.0, 5.0, 4.0),對x進行softmax轉換輸出為:
假設該樣本y=[0, 1, 0],那損失loss:
按softmax交叉熵優化時,針對這個樣本而言,會讓0.721越來越接近於1,因為這樣會減少loss,但是這有可能造成過擬合。可以這樣理解,如果0.721已經接近於1了,那么網絡會對該樣本十分“關注”,也就是過擬合。我們可以通過標簽平滑的方式解決。
以下是論文中對此問題的闡述:
工作原理 |
假設有一批數據在神經網絡最后一層的輸出值和他們的真實標簽
out = np.array([[4.0, 5.0, 10.0], [1.0, 5.0, 4.0], [1.0, 15.0, 4.0]])
y = np.array([[0, 0, 1], [0, 1, 0], [0, 1, 0]])
直接計算softmax交叉熵損失:
res = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=out, label_smoothing=0)
print(tf.Session().run(res))
結果為:0.11191821843385696
使用標簽平滑后:
res2 = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=out, label_smoothing=0.001)
print(tf.Session().run(res2))
結果為:0.11647378653287888
可以看出,損失比之前增加了,他的標簽平滑的原理是對真實標簽做了改變,源碼里的公式為:
# new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes
new_onehot_labels = y * (1 - 0.001) + 0.001 / 3
print(y)
print(new_onehot_labels)
[[0 0 1]
[0 1 0]
[0 1 0]]
[[3.33333333e-04 3.33333333e-04 9.99333333e-01]
[3.33333333e-04 9.99333333e-01 3.33333333e-04]
[3.33333333e-04 9.99333333e-01 3.33333333e-04]]
然后使用平滑標簽計算softmax交叉熵就能得到最終的結果了,我們也可以驗證一下:
res3 = tf.losses.softmax_cross_entropy(onehot_labels=new_onehot_labels, logits=out, label_smoothing=0)
print(tf.Session().run(res3))
結果為:0.11647378653287888
完整代碼:

import numpy as np import tensorflow as tf out = np.array([[4.0, 5.0, 10.0], [1.0, 5.0, 4.0], [1.0, 15.0, 4.0]]) y = np.array([[0, 0, 1], [0, 1, 0], [0, 1, 0]]) res = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=out, label_smoothing=0) print(tf.Session().run(res)) res2 = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=out, label_smoothing=0.001) print(tf.Session().run(res2)) # new_onehot_labels = onehot_labels * (1 - label_smoothing) # + label_smoothing / num_classes new_onehot_labels = y * (1 - 0.001) + 0.001 / 3 print(y) print(new_onehot_labels) res3 = tf.losses.softmax_cross_entropy(onehot_labels=new_onehot_labels, logits=out, label_smoothing=0) print(tf.Session().run(res3))
參考資料 |
Rethinking the Inception Architecture for Computer Vision
標簽平滑(Label Smoothing)——分類問題中錯誤標注的一種解決方法
https://www.datalearner.com/blog/1051561454844661