系列博客,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點擊star加星不要吝嗇,星越多筆者越努力。
第3章 損失函數
3.0 損失函數概論
3.0.1 概念
在各種材料中經常看到的中英文詞匯有:誤差,偏差,Error,Cost,Loss,損失,代價......意思都差不多,在本書中,使用“損失函數”和“Loss Function”這兩個詞匯,具體的損失函數符號用J來表示,誤差值用loss表示。
“損失”就是所有樣本的“誤差”的總和,亦即(m為樣本數):
在黑盒子的例子中,我們如果說“某個樣本的損失”是不對的,只能說“某個樣本的誤差”,因為樣本是一個一個計算的。如果我們把神經網絡的參數調整到完全滿足獨立樣本的輸出誤差為0,通常會令其它樣本的誤差變得更大,這樣作為誤差之和的損失函數值,就會變得更大。所以,我們通常會在根據某個樣本的誤差調整權重后,計算一下整體樣本的損失函數值,來判定網絡是不是已經訓練到了可接受的狀態。
損失函數的作用
損失函數的作用,就是計算神經網絡每次迭代的前向計算結果與真實值的差距,從而指導下一步的訓練向正確的方向進行。
如何使用損失函數呢?具體步驟:
- 用隨機值初始化前向計算公式的參數;
- 代入樣本,計算輸出的預測值;
- 用損失函數計算預測值和標簽值(真實值)的誤差;
- 根據損失函數的導數,沿梯度最小方向將誤差回傳,修正前向計算公式中的各個權重值;
- goto 2, 直到損失函數值達到一個滿意的值就停止迭代。
3.0.2 機器學習常用損失函數
符號規則:a是預測值,y是樣本標簽值,J是損失函數值。
- Gold Standard Loss,又稱0-1誤差
- 絕對值損失函數
- Hinge Loss,鉸鏈/折頁損失函數或最大邊界損失函數,主要用於SVM(支持向量機)中
- Log Loss,對數損失函數,又叫交叉熵損失函數(cross entropy error)
- Squared Loss,均方差損失函數
- Exponential Loss,指數損失函數
3.0.3 損失函數圖像理解
用二維函數圖像理解單變量對損失函數的影響
圖3-1 單變量的損失函數圖
圖3-1中,縱坐標是損失函數值,橫坐標是變量。不斷地改變變量的值,會造成損失函數值的上升或下降。而梯度下降算法會讓我們沿着損失函數值下降的方向前進。
- 假設我們的初始位置在A點,\(x=x0\),損失函數值(縱坐標)較大,回傳給網絡做訓練;
- 經過一次迭代后,我們移動到了B點,\(x=x1\),損失函數值也相應減小,再次回傳重新訓練;
- 以此節奏不斷向損失函數的最低點靠近,經歷了\(x2、x3、x4、x5\);
- 直到損失值達到可接受的程度,比如\(x5\)的位置,就停止訓練。
用等高線圖理解雙變量對損失函數影響
圖3-2 雙變量的損失函數圖
圖3-2中,橫坐標是一個變量\(w\),縱坐標是另一個變量\(b\)。兩個變量的組合形成的損失函數值,在圖中對應處於等高線上的唯一的一個坐標點。\(w、b\)所有的不同的值的組合會形成一個損失函數值的矩陣,我們把矩陣中具有相同(相近)損失函數值的點連接起來,可以形成一個不規則橢圓,其圓心位置,是損失值為0的位置,也是我們要逼近的目標。
這個橢圓如同平面地圖的等高線,來表示的一個窪地,中心位置比邊緣位置要低,通過對損失函數值的計算,對損失函數的求導,會帶領我們沿着等高線形成的梯子一步步下降,無限逼近中心點。
3.0.4 神經網絡中常用的損失函數
-
均方差函數,主要用於回歸
-
交叉熵函數,主要用於分類
二者都是非負函數,極值在底部,用梯度下降法可以求解。
系列博客,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點擊star加星不要吝嗇,星越多筆者越努力。
3.1 均方差函數
MSE - Mean Square Error。
該函數就是最直觀的一個損失函數了,計算預測值和真實值之間的歐式距離。預測值和真實值越接近,兩者的均方差就越小。
均方差函數常用於線性回歸(linear regression),即函數擬合(function fitting)。公式如下:
3.1.1 工作原理
要想得到預測值a與真實值y的差距,最朴素的想法就是用\(Error=a_i-y_i\)。
對於單個樣本來說,這樣做沒問題,但是多個樣本累計時,\(a_i-y_i\)有可能有正有負,誤差求和時就會導致相互抵消,從而失去價值。所以有了絕對值差的想法,即\(Error=|a_i-y_i|\)。這看上去很簡單,並且也很理想,那為什么還要引入均方差損失函數呢?兩種損失函數的比較如表3-1所示。
表3-1 絕對值損失函數與均方差損失函數的比較
樣本標簽值 | 樣本預測值 | 絕對值損失函數 | 均方差損失函數 |
---|---|---|---|
\([1,1,1]\) | \([1,2,3]\) | \((1-1)+(2-1)+(3-1)=3\) | \((1-1)^2+(2-1)^2+(3-1)^2=5\) |
\([1,1,1]\) | \([1,3,3]\) | \((1-1)+(3-1)+(3-1)=4\) | \((1-1)^2+(3-1)^2+(3-1)^2=8\) |
\(4/3=1.33\) | \(8/5=1.6\) |
可以看到5比3已經大了很多,8比4大了一倍,而8比5也放大了某個樣本的局部損失對全局帶來的影響,用術語說,就是“對某些偏離大的樣本比較敏感”,從而引起監督訓練過程的足夠重視,以便回傳誤差。
3.1.2 實際案例
假設有一組數據如圖3-3,我們想找到一條擬合的直線。
圖3-3 平面上的樣本數據
圖3-4中,前三張顯示了一個逐漸找到最佳擬合直線的過程。
- 第一張,用均方差函數計算得到Loss=0.53;
- 第二張,直線向上平移一些,誤差計算Loss=0.16,比圖一的誤差小很多;
- 第三張,又向上平移了一些,誤差計算Loss=0.048,此后還可以繼續嘗試平移(改變b值)或者變換角度(改變w值),得到更小的損失函數值;
- 第四張,偏離了最佳位置,誤差值Loss=0.18,這種情況,算法會讓嘗試方向反向向下。
圖3-4 損失函數值與直線位置的關系
第三張圖損失函數值最小的情況。比較第二張和第四張圖,由於均方差的損失函數值都是正值,如何判斷是向上移動還是向下移動呢?
在實際的訓練過程中,是沒有必要計算損失函數值的,因為損失函數值會體現在反向傳播的過程中。我們來看看均方差函數的導數:
雖然\((a_i-y_i)^2\)永遠是正數,但是\(a_i-y_i\)卻可以是正數(直線在點下方時)或者負數(直線在點上方時),這個正數或者負數被反向傳播回到前面的計算過程中,就會引導訓練過程朝正確的方向嘗試。
在上面的例子中,我們有兩個變量,一個w,一個b,這兩個值的變化都會影響最終的損失函數值的。
我們假設該擬合直線的方程是y=2x+3,當我們固定w=2,把b值從2到4變化時,看看損失函數值的變化如圖3-5所示。
圖3-5 固定W時,b的變化造成的損失值
我們假設該擬合直線的方程是y=2x+3,當我們固定b=3,把w值從1到3變化時,看看損失函數值的變化如圖3-6所示。
圖3-6 固定b時,W的變化造成的損失值
3.1.3 損失函數的可視化
損失函數值的3D示意圖
橫坐標為W,縱坐標為b,針對每一個w和一個b的組合計算出一個損失函數值,用三維圖的高度來表示這個損失函數值。下圖中的底部並非一個平面,而是一個有些下凹的曲面,只不過曲率較小,如圖3-7。
圖3-7 W和b同時變化時的損失值形成的曲面
損失函數值的2D示意圖
在平面地圖中,我們經常會看到用等高線的方式來表示海拔高度值,下圖就是上圖在平面上的投影,即損失函數值的等高線圖,如圖3-8所示。
圖3-8 損失函數的等高線圖
如果還不能理解的話,我們用最笨的方法來畫一張圖,代碼如下:
s = 200
W = np.linspace(w-2,w+2,s)
B = np.linspace(b-2,b+2,s)
LOSS = np.zeros((s,s))
for i in range(len(W)):
for j in range(len(B)):
z = W[i] * x + B[j]
loss = CostFunction(x,y,z,m)
LOSS[i,j] = round(loss, 2)
上述代碼針對每個w和b的組合計算出了一個損失值,保留小數點后2位,放在LOSS矩陣中,如下所示:
[[4.69 4.63 4.57 ... 0.72 0.74 0.76]
[4.66 4.6 4.54 ... 0.73 0.75 0.77]
[4.62 4.56 4.5 ... 0.73 0.75 0.77]
...
[0.7 0.68 0.66 ... 4.57 4.63 4.69]
[0.69 0.67 0.65 ... 4.6 4.66 4.72]
[0.68 0.66 0.64 ... 4.63 4.69 4.75]]
然后遍歷矩陣中的損失函數值,在具有相同值的位置上繪制相同顏色的點,比如,把所有值為0.72的點繪制成紅色,把所有值為0.75的點繪制成藍色......,這樣就可以得到圖3-9。
圖3-9 用笨辦法繪制等高線圖
此圖和等高線圖的表達方式等價,但由於等高線圖比較簡明清晰,所以以后我們都使用等高線圖來說明問題。
代碼位置
ch03, Level1
系列博客,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點擊star加星不要吝嗇,星越多筆者越努力。
3.2 交叉熵損失函數
交叉熵(Cross Entropy)是Shannon信息論中一個重要概念,主要用於度量兩個概率分布間的差異性信息。在信息論中,交叉熵是表示兩個概率分布 \(p,q\) 的差異,其中 \(p\) 表示真實分布,\(q\) 表示非真實分布,那么\(H(p,q)\)就稱為交叉熵:
交叉熵可在神經網絡中作為損失函數,\(p\) 表示真實標記的分布,\(q\) 則為訓練后的模型的預測標記分布,交叉熵損失函數可以衡量 \(p\) 與 \(q\) 的相似性。
交叉熵函數常用於邏輯回歸(logistic regression),也就是分類(classification)。
3.2.1 交叉熵的由來
信息量
信息論中,信息量的表示方式:
\(x_j\):表示一個事件
\(p(x_j)\):表示\(x_j\)發生的概率
\(I(x_j)\):信息量,\(x_j\)越不可能發生時,它一旦發生后的信息量就越大
假設對於學習神經網絡原理課程,我們有三種可能的情況發生,如表3-2所示。
表3-2 三種事件的概論和信息量
事件編號 | 事件 | 概率 \(p\) | 信息量 \(I\) |
---|---|---|---|
\(x_1\) | 優秀 | \(p=0.7\) | \(I=-\ln(0.7)=0.36\) |
\(x_2\) | 及格 | \(p=0.2\) | \(I=-\ln(0.2)=1.61\) |
\(x_3\) | 不及格 | \(p=0.1\) | \(I=-\ln(0.1)=2.30\) |
WoW,某某同學不及格!好大的信息量!相比較來說,“優秀”事件的信息量反而小了很多。
熵
則上面的問題的熵是:
相對熵(KL散度)
相對熵又稱KL散度,如果我們對於同一個隨機變量 \(x\) 有兩個單獨的概率分布 \(P(x)\) 和 \(Q(x)\),我們可以使用 KL 散度(Kullback-Leibler (KL) divergence)來衡量這兩個分布的差異,這個相當於信息論范疇的均方差。
KL散度的計算公式:
\(n\) 為事件的所有可能性。\(D\) 的值越小,表示 \(q\) 分布和 \(p\) 分布越接近。
交叉熵
把上述公式變形:
等式的前一部分恰巧就是p的熵,等式的后一部分,就是交叉熵:
在機器學習中,我們需要評估label和predicts之間的差距,使用KL散度剛剛好,即\(D_{KL}(y||a)\),由於KL散度中的前一部分\(H(y)\)不變,故在優化過程中,只需要關注交叉熵就可以了。所以一般在機器學習中直接用交叉熵做損失函數來評估模型。
其中,\(n\) 並不是樣本個數,而是分類個數。所以,對於批量樣本的交叉熵計算公式是:
\(m\) 是樣本數,\(n\) 是分類數。
有一類特殊問題,就是事件只有兩種情況發生的可能,比如“學會了”和“沒學會”,稱為\(0/1\)分布或二分類。對於這類問題,由於\(n=2\),所以交叉熵可以簡化為:
二分類對於批量樣本的交叉熵計算公式是:
3.2.2 二分類問題交叉熵
把公式10分解開兩種情況,當\(y=1\)時,即標簽值是1,是個正例,加號后面的項為0:
橫坐標是預測輸出,縱坐標是損失函數值。y=1意味着當前樣本標簽值是1,當預測輸出越接近1時,損失函數值越小,訓練結果越准確。當預測輸出越接近0時,損失函數值越大,訓練結果越糟糕。
當y=0時,即標簽值是0,是個反例,加號前面的項為0:
此時,損失函數值如圖3-10。
圖3-10 二分類交叉熵損失函數圖
假設學會了課程的標簽值為1,沒有學會的標簽值為0。我們想建立一個預測器,對於一個特定的學員,根據出勤率、課堂表現、作業情況、學習能力等等來預測其學會課程的概率。
對於學員甲,預測其學會的概率為0.6,而實際上該學員通過了考試,真實值為1。所以,學員甲的交叉熵損失函數值是:
對於學員乙,預測其學會的概率為0.7,而實際上該學員也通過了考試。所以,學員乙的交叉熵損失函數值是:
由於0.7比0.6更接近1,是相對准確的值,所以 \(loss2\) 要比 \(loss1\) 小,反向傳播的力度也會小。
3.2.3 多分類問題交叉熵
當標簽值不是非0即1的情況時,就是多分類了。假設期末考試有三種情況:
- 優秀,標簽值OneHot編碼為\([1,0,0]\)
- 及格,標簽值OneHot編碼為\([0,1,0]\)
- 不及格,標簽值OneHot編碼為\([0,0,1]\)
假設我們預測學員丙的成績為優秀、及格、不及格的概率為:\([0.2,0.5,0.3]\),而真實情況是該學員不及格,則得到的交叉熵是:
假設我們預測學員丁的成績為優秀、及格、不及格的概率為:\([0.2,0.2,0.6]\),而真實情況是該學員不及格,則得到的交叉熵是:
可以看到,0.51比1.2的損失值小很多,這說明預測值越接近真實標簽值(0.6 vs 0.3),交叉熵損失函數值越小,反向傳播的力度越小。
3.2.4 為什么不能使用均方差做為分類問題的損失函數?
-
回歸問題通常用均方差損失函數,可以保證損失函數是個凸函數,即可以得到最優解。而分類問題如果用均方差的話,損失函數的表現不是凸函數,就很難得到最優解。而交叉熵函數可以保證區間內單調。
-
分類問題的最后一層網絡,需要分類函數,Sigmoid或者Softmax,如果再接均方差函數的話,其求導結果復雜,運算量比較大。用交叉熵函數的話,可以得到比較簡單的計算結果,一個簡單的減法就可以得到反向誤差。