神經網絡系列之三 -- 損失函數


系列博客,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點擊star加星不要吝嗇,星越多筆者越努力。

第3章 損失函數

3.0 損失函數概論

3.0.1 概念

在各種材料中經常看到的中英文詞匯有:誤差,偏差,Error,Cost,Loss,損失,代價......意思都差不多,在本書中,使用“損失函數”和“Loss Function”這兩個詞匯,具體的損失函數符號用J來表示,誤差值用loss表示。

“損失”就是所有樣本的“誤差”的總和,亦即(m為樣本數):

\[損失 = \sum^m_{i=1}誤差_i \]

\[J = \sum_{i=1}^m loss \]

在黑盒子的例子中,我們如果說“某個樣本的損失”是不對的,只能說“某個樣本的誤差”,因為樣本是一個一個計算的。如果我們把神經網絡的參數調整到完全滿足獨立樣本的輸出誤差為0,通常會令其它樣本的誤差變得更大,這樣作為誤差之和的損失函數值,就會變得更大。所以,我們通常會在根據某個樣本的誤差調整權重后,計算一下整體樣本的損失函數值,來判定網絡是不是已經訓練到了可接受的狀態。

損失函數的作用

損失函數的作用,就是計算神經網絡每次迭代的前向計算結果與真實值的差距,從而指導下一步的訓練向正確的方向進行。

如何使用損失函數呢?具體步驟:

  1. 用隨機值初始化前向計算公式的參數;
  2. 代入樣本,計算輸出的預測值;
  3. 用損失函數計算預測值和標簽值(真實值)的誤差;
  4. 根據損失函數的導數,沿梯度最小方向將誤差回傳,修正前向計算公式中的各個權重值;
  5. goto 2, 直到損失函數值達到一個滿意的值就停止迭代。

3.0.2 機器學習常用損失函數

符號規則:a是預測值,y是樣本標簽值,J是損失函數值。

  • Gold Standard Loss,又稱0-1誤差

\[loss=\begin{cases} 0 & a=y \\ 1 & a \ne y \end{cases} \]

  • 絕對值損失函數

\[loss = |y-a| \]

  • Hinge Loss,鉸鏈/折頁損失函數或最大邊界損失函數,主要用於SVM(支持向量機)中

\[loss=max(0,1-y \cdot a), y=\pm 1 \]

  • Log Loss,對數損失函數,又叫交叉熵損失函數(cross entropy error)

\[loss = -\frac{1}{m} \sum_i^m y_i log(a_i) + (1-y_i)log(1-a_i) \qquad y_i \in \{0,1\} \]

  • Squared Loss,均方差損失函數

\[loss=\frac{1}{2m} \sum_i^m (a_i-y_i)^2 \]

  • Exponential Loss,指數損失函數

\[loss = \frac{1}{m}\sum_i^m e^{-(y_i \cdot a_i)} \]

3.0.3 損失函數圖像理解

用二維函數圖像理解單變量對損失函數的影響

圖3-1 單變量的損失函數圖

圖3-1中,縱坐標是損失函數值,橫坐標是變量。不斷地改變變量的值,會造成損失函數值的上升或下降。而梯度下降算法會讓我們沿着損失函數值下降的方向前進。

  1. 假設我們的初始位置在A點,\(x=x0\),損失函數值(縱坐標)較大,回傳給網絡做訓練;
  2. 經過一次迭代后,我們移動到了B點,\(x=x1\),損失函數值也相應減小,再次回傳重新訓練;
  3. 以此節奏不斷向損失函數的最低點靠近,經歷了\(x2、x3、x4、x5\)
  4. 直到損失值達到可接受的程度,比如\(x5\)的位置,就停止訓練。

用等高線圖理解雙變量對損失函數影響

圖3-2 雙變量的損失函數圖

圖3-2中,橫坐標是一個變量\(w\),縱坐標是另一個變量\(b\)。兩個變量的組合形成的損失函數值,在圖中對應處於等高線上的唯一的一個坐標點。\(w、b\)所有的不同的值的組合會形成一個損失函數值的矩陣,我們把矩陣中具有相同(相近)損失函數值的點連接起來,可以形成一個不規則橢圓,其圓心位置,是損失值為0的位置,也是我們要逼近的目標。

這個橢圓如同平面地圖的等高線,來表示的一個窪地,中心位置比邊緣位置要低,通過對損失函數值的計算,對損失函數的求導,會帶領我們沿着等高線形成的梯子一步步下降,無限逼近中心點。

3.0.4 神經網絡中常用的損失函數

  • 均方差函數,主要用於回歸

  • 交叉熵函數,主要用於分類

二者都是非負函數,極值在底部,用梯度下降法可以求解。

系列博客,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點擊star加星不要吝嗇,星越多筆者越努力。

3.1 均方差函數

MSE - Mean Square Error。

該函數就是最直觀的一個損失函數了,計算預測值和真實值之間的歐式距離。預測值和真實值越接近,兩者的均方差就越小。

均方差函數常用於線性回歸(linear regression),即函數擬合(function fitting)。公式如下:

\[loss = {1 \over 2}(z-y)^2 \tag{單樣本} \]

\[J=\frac{1}{2m} \sum_{i=1}^m (z_i-y_i)^2 \tag{多樣本} \]

3.1.1 工作原理

要想得到預測值a與真實值y的差距,最朴素的想法就是用\(Error=a_i-y_i\)

對於單個樣本來說,這樣做沒問題,但是多個樣本累計時,\(a_i-y_i\)有可能有正有負,誤差求和時就會導致相互抵消,從而失去價值。所以有了絕對值差的想法,即\(Error=|a_i-y_i|\)。這看上去很簡單,並且也很理想,那為什么還要引入均方差損失函數呢?兩種損失函數的比較如表3-1所示。

表3-1 絕對值損失函數與均方差損失函數的比較

樣本標簽值 樣本預測值 絕對值損失函數 均方差損失函數
\([1,1,1]\) \([1,2,3]\) \((1-1)+(2-1)+(3-1)=3\) \((1-1)^2+(2-1)^2+(3-1)^2=5\)
\([1,1,1]\) \([1,3,3]\) \((1-1)+(3-1)+(3-1)=4\) \((1-1)^2+(3-1)^2+(3-1)^2=8\)
\(4/3=1.33\) \(8/5=1.6\)

可以看到5比3已經大了很多,8比4大了一倍,而8比5也放大了某個樣本的局部損失對全局帶來的影響,用術語說,就是“對某些偏離大的樣本比較敏感”,從而引起監督訓練過程的足夠重視,以便回傳誤差。

3.1.2 實際案例

假設有一組數據如圖3-3,我們想找到一條擬合的直線。

圖3-3 平面上的樣本數據

圖3-4中,前三張顯示了一個逐漸找到最佳擬合直線的過程。

  • 第一張,用均方差函數計算得到Loss=0.53;
  • 第二張,直線向上平移一些,誤差計算Loss=0.16,比圖一的誤差小很多;
  • 第三張,又向上平移了一些,誤差計算Loss=0.048,此后還可以繼續嘗試平移(改變b值)或者變換角度(改變w值),得到更小的損失函數值;
  • 第四張,偏離了最佳位置,誤差值Loss=0.18,這種情況,算法會讓嘗試方向反向向下。

圖3-4 損失函數值與直線位置的關系

第三張圖損失函數值最小的情況。比較第二張和第四張圖,由於均方差的損失函數值都是正值,如何判斷是向上移動還是向下移動呢?

在實際的訓練過程中,是沒有必要計算損失函數值的,因為損失函數值會體現在反向傳播的過程中。我們來看看均方差函數的導數:

\[\frac{\partial{J}}{\partial{a_i}} = a_i-y_i \]

雖然\((a_i-y_i)^2\)永遠是正數,但是\(a_i-y_i\)卻可以是正數(直線在點下方時)或者負數(直線在點上方時),這個正數或者負數被反向傳播回到前面的計算過程中,就會引導訓練過程朝正確的方向嘗試。

在上面的例子中,我們有兩個變量,一個w,一個b,這兩個值的變化都會影響最終的損失函數值的。

我們假設該擬合直線的方程是y=2x+3,當我們固定w=2,把b值從2到4變化時,看看損失函數值的變化如圖3-5所示。

圖3-5 固定W時,b的變化造成的損失值

我們假設該擬合直線的方程是y=2x+3,當我們固定b=3,把w值從1到3變化時,看看損失函數值的變化如圖3-6所示。

圖3-6 固定b時,W的變化造成的損失值

3.1.3 損失函數的可視化

損失函數值的3D示意圖

橫坐標為W,縱坐標為b,針對每一個w和一個b的組合計算出一個損失函數值,用三維圖的高度來表示這個損失函數值。下圖中的底部並非一個平面,而是一個有些下凹的曲面,只不過曲率較小,如圖3-7。

圖3-7 W和b同時變化時的損失值形成的曲面

損失函數值的2D示意圖

在平面地圖中,我們經常會看到用等高線的方式來表示海拔高度值,下圖就是上圖在平面上的投影,即損失函數值的等高線圖,如圖3-8所示。

圖3-8 損失函數的等高線圖

如果還不能理解的話,我們用最笨的方法來畫一張圖,代碼如下:

    s = 200
    W = np.linspace(w-2,w+2,s)
    B = np.linspace(b-2,b+2,s)
    LOSS = np.zeros((s,s))
    for i in range(len(W)):
        for j in range(len(B)):
            z = W[i] * x + B[j]
            loss = CostFunction(x,y,z,m)
            LOSS[i,j] = round(loss, 2)

上述代碼針對每個w和b的組合計算出了一個損失值,保留小數點后2位,放在LOSS矩陣中,如下所示:

[[4.69 4.63 4.57 ... 0.72 0.74 0.76]
 [4.66 4.6  4.54 ... 0.73 0.75 0.77]
 [4.62 4.56 4.5  ... 0.73 0.75 0.77]
 ...
 [0.7  0.68 0.66 ... 4.57 4.63 4.69]
 [0.69 0.67 0.65 ... 4.6  4.66 4.72]
 [0.68 0.66 0.64 ... 4.63 4.69 4.75]]

然后遍歷矩陣中的損失函數值,在具有相同值的位置上繪制相同顏色的點,比如,把所有值為0.72的點繪制成紅色,把所有值為0.75的點繪制成藍色......,這樣就可以得到圖3-9。

圖3-9 用笨辦法繪制等高線圖

此圖和等高線圖的表達方式等價,但由於等高線圖比較簡明清晰,所以以后我們都使用等高線圖來說明問題。

代碼位置

ch03, Level1

系列博客,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點擊star加星不要吝嗇,星越多筆者越努力。

3.2 交叉熵損失函數

交叉熵(Cross Entropy)是Shannon信息論中一個重要概念,主要用於度量兩個概率分布間的差異性信息。在信息論中,交叉熵是表示兩個概率分布 \(p,q\) 的差異,其中 \(p\) 表示真實分布,\(q\) 表示非真實分布,那么\(H(p,q)\)就稱為交叉熵:

\[H(p,q)=\sum_i p_i \cdot \ln {1 \over q_i} = - \sum_i p_i \ln q_i \tag{1} \]

交叉熵可在神經網絡中作為損失函數,\(p\) 表示真實標記的分布,\(q\) 則為訓練后的模型的預測標記分布,交叉熵損失函數可以衡量 \(p\)\(q\) 的相似性。

交叉熵函數常用於邏輯回歸(logistic regression),也就是分類(classification)。

3.2.1 交叉熵的由來

信息量

信息論中,信息量的表示方式:

\[I(x_j) = -\ln (p(x_j)) \tag{2} \]

\(x_j\):表示一個事件

\(p(x_j)\):表示\(x_j\)發生的概率

\(I(x_j)\):信息量,\(x_j\)越不可能發生時,它一旦發生后的信息量就越大

假設對於學習神經網絡原理課程,我們有三種可能的情況發生,如表3-2所示。

表3-2 三種事件的概論和信息量

事件編號 事件 概率 \(p\) 信息量 \(I\)
\(x_1\) 優秀 \(p=0.7\) \(I=-\ln(0.7)=0.36\)
\(x_2\) 及格 \(p=0.2\) \(I=-\ln(0.2)=1.61\)
\(x_3\) 不及格 \(p=0.1\) \(I=-\ln(0.1)=2.30\)

WoW,某某同學不及格!好大的信息量!相比較來說,“優秀”事件的信息量反而小了很多。

\[H(p) = - \sum_j^n p(x_j) \ln (p(x_j)) \tag{3} \]

則上面的問題的熵是:

\[\begin{aligned} H(p)&=-[p(x_1) \ln p(x_1) + p(x_2) \ln p(x_2) + p(x_3) \ln p(x_3)] \\ &=0.7 \times 0.36 + 0.2 \times 1.61 + 0.1 \times 2.30 \\ &=0.804 \end{aligned} \]

相對熵(KL散度)

相對熵又稱KL散度,如果我們對於同一個隨機變量 \(x\) 有兩個單獨的概率分布 \(P(x)\)\(Q(x)\),我們可以使用 KL 散度(Kullback-Leibler (KL) divergence)來衡量這兩個分布的差異,這個相當於信息論范疇的均方差。

KL散度的計算公式:

\[D_{KL}(p||q)=\sum_{j=1}^n p(x_j) \ln{p(x_j) \over q(x_j)} \tag{4} \]

\(n\) 為事件的所有可能性。\(D\) 的值越小,表示 \(q\) 分布和 \(p\) 分布越接近。

交叉熵

把上述公式變形:

\[\begin{aligned} D_{KL}(p||q)&=\sum_{j=1}^n p(x_j) \ln{p(x_j)} - \sum_{j=1}^n p(x_j) \ln q(x_j) \\ &=- H(p(x)) + H(p,q) \end{aligned} \tag{5} \]

等式的前一部分恰巧就是p的熵,等式的后一部分,就是交叉熵:

\[H(p,q) =- \sum_{j=1}^n p(x_j) \ln q(x_j) \tag{6} \]

在機器學習中,我們需要評估label和predicts之間的差距,使用KL散度剛剛好,即\(D_{KL}(y||a)\),由於KL散度中的前一部分\(H(y)\)不變,故在優化過程中,只需要關注交叉熵就可以了。所以一般在機器學習中直接用交叉熵做損失函數來評估模型。

\[loss =- \sum_{j=1}^n y_j \ln a_j \tag{7} \]

其中,\(n\) 並不是樣本個數,而是分類個數。所以,對於批量樣本的交叉熵計算公式是:

\[J =- \sum_{i=1}^m \sum_{j=1}^n y_{ij} \ln a_{ij} \tag{8} \]

\(m\) 是樣本數,\(n\) 是分類數。

有一類特殊問題,就是事件只有兩種情況發生的可能,比如“學會了”和“沒學會”,稱為\(0/1\)分布或二分類。對於這類問題,由於\(n=2\),所以交叉熵可以簡化為:

\[loss =-[y \ln a + (1-y) \ln (1-a)] \tag{9} \]

二分類對於批量樣本的交叉熵計算公式是:

\[J= - \sum_{i=1}^m [y_i \ln a_i + (1-y_i) \ln (1-a_i)] \tag{10} \]

3.2.2 二分類問題交叉熵

把公式10分解開兩種情況,當\(y=1\)時,即標簽值是1,是個正例,加號后面的項為0:

\[loss = -\ln(a) \tag{11} \]

橫坐標是預測輸出,縱坐標是損失函數值。y=1意味着當前樣本標簽值是1,當預測輸出越接近1時,損失函數值越小,訓練結果越准確。當預測輸出越接近0時,損失函數值越大,訓練結果越糟糕。

當y=0時,即標簽值是0,是個反例,加號前面的項為0:

\[loss = -\ln (1-a) \tag{12} \]

此時,損失函數值如圖3-10。

圖3-10 二分類交叉熵損失函數圖

假設學會了課程的標簽值為1,沒有學會的標簽值為0。我們想建立一個預測器,對於一個特定的學員,根據出勤率、課堂表現、作業情況、學習能力等等來預測其學會課程的概率。

對於學員甲,預測其學會的概率為0.6,而實際上該學員通過了考試,真實值為1。所以,學員甲的交叉熵損失函數值是:

\[loss_1 = -(1 \times \ln 0.6 + (1-1) \times \ln (1-0.6)) = 0.51 \]

對於學員乙,預測其學會的概率為0.7,而實際上該學員也通過了考試。所以,學員乙的交叉熵損失函數值是:

\[loss_2 = -(1 \times \ln 0.7 + (1-1) \times \ln (1-0.7)) = 0.36 \]

由於0.7比0.6更接近1,是相對准確的值,所以 \(loss2\) 要比 \(loss1\) 小,反向傳播的力度也會小。

3.2.3 多分類問題交叉熵

當標簽值不是非0即1的情況時,就是多分類了。假設期末考試有三種情況:

  1. 優秀,標簽值OneHot編碼為\([1,0,0]\)
  2. 及格,標簽值OneHot編碼為\([0,1,0]\)
  3. 不及格,標簽值OneHot編碼為\([0,0,1]\)

假設我們預測學員丙的成績為優秀、及格、不及格的概率為:\([0.2,0.5,0.3]\),而真實情況是該學員不及格,則得到的交叉熵是:

\[loss_1 = -(0 \times \ln 0.2 + 0 \times \ln 0.5 + 1 \times \ln 0.3) = 1.2 \]

假設我們預測學員丁的成績為優秀、及格、不及格的概率為:\([0.2,0.2,0.6]\),而真實情況是該學員不及格,則得到的交叉熵是:

\[loss_2 = -(0 \times \ln 0.2 + 0 \times \ln 0.2 + 1 \times \ln 0.6) = 0.51 \]

可以看到,0.51比1.2的損失值小很多,這說明預測值越接近真實標簽值(0.6 vs 0.3),交叉熵損失函數值越小,反向傳播的力度越小。

3.2.4 為什么不能使用均方差做為分類問題的損失函數?

  1. 回歸問題通常用均方差損失函數,可以保證損失函數是個凸函數,即可以得到最優解。而分類問題如果用均方差的話,損失函數的表現不是凸函數,就很難得到最優解。而交叉熵函數可以保證區間內單調。

  2. 分類問題的最后一層網絡,需要分類函數,Sigmoid或者Softmax,如果再接均方差函數的話,其求導結果復雜,運算量比較大。用交叉熵函數的話,可以得到比較簡單的計算結果,一個簡單的減法就可以得到反向誤差。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM