出發點
對於一個樣本,有輸入和輸出結果,我們的目的是優化訓練我們的模型,使得對於樣本輸入,模型的預測輸出盡可能的接近真實輸出結果。現在需要一個損失函數來評估預測輸出與真實結果的差距。
均方誤差
回歸問題
樣本有若干維,每一維都有一個真實值。我們要將樣本的數據通過我們的模型預測也得到同樣多的預測值,真實值可以看成一個向量,預測值也一樣。預測值向量要在某種定義下與真實值向量是接近的。
定義
其中\(N\)為樣本的總維數,\(y_i\)表示第i維的真實值,\(\hat y_i\)表示第i維的預測值,這個誤差函數是容易理解的。
如果把這個樣本看做N維空間中的一個向量,均方誤差實際上是這真實值與預測值兩個向量的歐氏距離
均方誤差實際上就是一種衡量“有多近”的標准,這個距離的定義顯然是合適的。
在實際應用中,我們需要利用梯度方法訓練模型,因此損失函數應當是容易計算梯度並且不會產生梯度消失的。
考慮\(L\)對每個\(\hat y_i\)的偏導
當預測值與真實值差別越大,即\(|\hat y_i-y_i|\)越大時,梯度的絕對值也是更大的,這符合我們的要求。
分類問題
與回歸問題不同的是,樣本的每一維的真實值不是連續,有序的,而是一個個離散的類別。
這些類別不僅不連續(離散),而且還是無序的,如果仍然用回歸問題的思路,直接去這個\(\hat y_i\)表示把第\(i\)維歸入哪一類顯然是不合適的,因為不存在2.5類這樣的東西。在這種情況下傳統的均方誤差也不適用了。因為第1類與第2類的差距並不一定比第1類與第9類的差距小,然而\((2-1)^2<(9-1)^2\),也就是說類之間沒法定義“距離”的概念了。
如果我們換一種思路呢,嘗試預測分到每一類的概率?
概率分布
假設總共有\(K\)類,對於每一維的預測值不是一個類別的確定值,而是一個K維的向量,代表分到每一類的一個概率分布。
設\(z_i^k\)表示樣本第\(i\)維被分到第k類的概率
神經網絡直接得到出來的預測值並不能滿足概率的要求——和為1
那么將這個預測值過一個Softmax
即
這樣就得到了一個和為1的概率分布,同時這個概率分布很好的突出了最大值(指數關系)。
舉例,[1,5,3],通過softmax后得到的是[0.015,0.866,0.117]
那對應的真實值又是什么呢?
如果訓練數據中樣本每一維的Label就是確定的類別,那么就是一個只有正確的類那一維概率為1,其他維都為0的K維向量。
設第i維的真實類別為\(y_i\),那么
在某些特殊情況中,訓練數據的Label不是確定的類別,而也是一個概率分布。
既然是兩個向量的差異,那這樣我們不可以直接做均方差嗎?
(同樣用\(\hat p_i^k\)代表預測值)
在實際情況中,分類問題往往不像回歸問題那樣要同時考慮多維,分類問題的樣本往往就是一維的(N=1),就是這個東西要分到哪一類,其實也就是加起來除以N的區別,不妨設N=1,那么下標\(i\)可以去掉了。
得到
我們下面將說明,在這里用均方差是不合適的。
梯度消失
由於這里的\(p\)是\(z\)過了一遍softmax的結果,我們還要把這個加上
就是先前求過的均方差的梯度
設
則
所以
我們前面說過,softmax的特點是突出大值,最大的那個\(\hat p_k\)大了,其他的就自然小了,\(\hat p_k\)不僅與\(\hat z_k\)有關,更重要的是\(\hat z_k\)在所有\(\hat z\)中的相對大小而不是絕對大小,所以我們主要關心真實概率\(p_k\)較大的那些類,只需要那些類被突出出來,其他自然就小了。
大的概率值應該是怎么樣的?
真實概率\(p_k\)較大,預測出來的\(\hat p_k\)我們也希望它比較大,當\(\hat p_k\)很小的時候就錯了,這個時候梯度絕對值應該更大
這時候就出問題了,當預測出來\(\hat p_k\)特別接近0的時候,均方差計算出的梯度非常小,當\(\hat p_k=0\)時梯度直接消失了,這明顯是有問題的。這就需要我們使用新的損失函數。
交叉熵
交叉熵本身是用來計算兩個概率分布之間的差異性信息的。
定義式是這樣的
log外面的是第k類的真實概率,里面的是第k類的預測概率
(在絕大多數分類問題中,\(p_k\)只有一個是1,其他都是0,都是確定的分類任務,也就是說求和式只有一項。但不失一般性,我們還是按照原來的形式討論)
現在我們來求它的梯度
求梯度
其中
這一部分與上面是一樣的。
而
那么
真實概率\(p_k\)較大,當\(\hat p_k\)很小的時候,梯度絕對值應該更大,可以觀察到上式是符合突出大值的要求的,在這種情況下按照梯度下降法走一步,\(\hat z_k\)的增大量會比其他\(z\)要大,\(\hat p_k\)就會被突出。
另外的想法
網上還有一些說法是,在分類問題中我們只關心最大值(因為最后輸出的答案還是要找一個概率最大的輸出),把整個分布擬合的那么像沒有意義。
在確定的分類任務中,對於那些錯誤的類(\(p_k=0\),預測值\(\hat p_k\)是多少並不重要,只要它不是最大的那個就行了,所以它是\(0.1\)還是\(0.01\)並不一定有很大的差別,而均方誤差的目標是概率分布的完全擬合,它可能過於嚴格了。)