這篇寫的比較詳細:
from: https://zhuanlan.zhihu.com/p/35709485
這篇文章中,討論的Cross Entropy損失函數常用於分類問題中,但是為什么它會在分類問題中這么有效呢?我們先從一個簡單的分類例子來入手。
1. 圖像分類任務
我們希望根據圖片動物的輪廓、顏色等特征,來預測動物的類別,有三種可預測類別:貓、狗、豬。假設我們當前有兩個模型(參數不同),這兩個模型都是通過sigmoid/softmax的方式得到對於每個預測結果的概率值:
模型1:
| 預測 | 真實 | 是否正確 |
|---|---|---|
| 0.3 0.3 0.4 | 0 0 1 (豬) | 正確 |
| 0.3 0.4 0.3 | 0 1 0 (狗) | 正確 |
| 0.1 0.2 0.7 | 1 0 0 (貓) | 錯誤 |
模型1對於樣本1和樣本2以非常微弱的優勢判斷正確,對於樣本3的判斷則徹底錯誤。
模型2:
| 預測 | 真實 | 是否正確 |
|---|---|---|
| 0.1 0.2 0.7 | 0 0 1 (豬) | 正確 |
| 0.1 0.7 0.2 | 0 1 0 (狗) | 正確 |
| 0.3 0.4 0.3 | 1 0 0 (貓) | 錯誤 |
模型2對於樣本1和樣本2判斷非常准確,對於樣本3判斷錯誤,但是相對來說沒有錯得太離譜。
好了,有了模型之后,我們需要通過定義損失函數來判斷模型在樣本上的表現了,那么我們可以定義哪些損失函數呢?
1.1 Classification Error(分類錯誤率)
最為直接的損失函數定義為: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1jbGFzc2lmaWNhdGlvbiU1QytlcnJvciUzRCU1Q2ZyYWMlN0Jjb3VudCU1QytvZiU1QytlcnJvciU1QytpdGVtcyU3RCU3QmNvdW50JTVDK29mKyU1QythbGwlNUMraXRlbXMlN0Q=.png)
模型1: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1jbGFzc2lmaWNhdGlvbiU1QytlcnJvciUzRCU1Q2ZyYWMlN0IxJTdEJTdCMyU3RA==.png)
模型2: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1jbGFzc2lmaWNhdGlvbiU1QytlcnJvciUzRCU1Q2ZyYWMlN0IxJTdEJTdCMyU3RA==.png)
我們知道,模型1和模型2雖然都是預測錯了1個,但是相對來說模型2表現得更好,損失函數值照理來說應該更小,但是,很遺憾的是,
並不能判斷出來,所以這種損失函數雖然好理解,但表現不太好。
1.2 Mean Squared Error (均方誤差)
均方誤差損失也是一種比較常見的損失函數,其定義為: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1NU0UlM0QlNUNmcmFjJTdCMSU3RCU3Qm4lN0QlNUNzdW1fJTdCaSU3RCU1RW4lMjglNUNoYXQlN0J5X2klN0QteV9pJTI5JTVFMg==.png)
模型1:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrKysrJTVDdGV4dCU3QnNhbXBsZSsxK2xvc3MlM0QlN0QlMjgwLjMtMCUyOSU1RTIrJTJCKyUyODAuMy0wJTI5JTVFMislMkIrJTI4MC40LTElMjklNUUyKyUzRCswLjU0KyU1QyU1QysrKyslNUN0ZXh0JTdCc2FtcGxlKzIrbG9zcyUzRCU3RCUyODAuMy0wJTI5JTVFMislMkIrJTI4MC40LTElMjklNUUyKyUyQislMjgwLjMtMCUyOSU1RTIrJTNEKzAuNTQrJTVDJTVDKysrKyU1Q3RleHQlN0JzYW1wbGUrMytsb3NzJTNEJTdEJTI4MC4xLTElMjklNUUyKyUyQislMjgwLjItMCUyOSU1RTIrJTJCKyUyODAuNy0wJTI5JTVFMislM0QrMS4zNCslNUMlNUMrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUMlNUM=.png)
對所有樣本的loss求平均:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1NU0UlM0QlNUNmcmFjJTdCMC41NCUyQjAuNTQlMkIxLjM0JTdEJTdCMyU3RCUzRDAuODErJTVDJTVD.png)
模型2:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrKyslMjYrJTVDdGV4dCU3QnNhbXBsZSsxK2xvc3MlM0QlN0QlMjgwLjEtMCUyOSU1RTIrJTJCKyUyODAuMi0wJTI5JTVFMislMkIrJTI4MC43LTElMjklNUUyKyUzRCswLjE0JTVDJTVDKysrKyUyNiU1Q3RleHQlN0JzYW1wbGUrMitsb3NzJTNEJTdEJTI4MC4xLTAlMjklNUUyKyUyQislMjgwLjctMSUyOSU1RTIrJTJCKyUyODAuMi0wJTI5JTVFMislM0QrMC4xNCU1QyU1QysrKyslMjYlNUN0ZXh0JTdCc2FtcGxlKzMrbG9zcyUzRCU3RCUyODAuMy0xJTI5JTVFMislMkIrJTI4MC40LTAlMjklNUUyKyUyQislMjgwLjMtMCUyOSU1RTIrJTNEKzAuNzQlNUMlNUMrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUMlNUM=.png)
對所有樣本的loss求平均:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1NU0UlM0QlNUNmcmFjJTdCMC4xNCUyQjAuMTQlMkIwLjc0JTdEJTdCMyU3RCUzRDAuMzQrJTVDJTVD.png)
我們發現,MSE能夠判斷出來模型2優於模型1,那為什么不采樣這種損失函數呢?主要原因是在分類問題中,使用sigmoid/softmx得到概率,配合MSE損失函數時,采用梯度下降法進行學習時,會出現模型一開始訓練時,學習速率非常慢的情況(MSE損失函數)。
有了上面的直觀分析,我們可以清楚的看到,對於分類問題的損失函數來說,分類錯誤率和均方誤差損失都不是很好的損失函數,下面我們來看一下交叉熵損失函數的表現情況。
1.3 Cross Entropy Loss Function(交叉熵損失函數)
1.3.1 表達式
(1) 二分類
在二分的情況下,模型最后需要預測的結果只有兩種情況,對於每個類別我們的預測得到的概率為
和
,此時表達式為:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MKyUzRCslNUNmcmFjJTdCMSU3RCU3Qk4lN0QlNUNzdW1fJTdCaSU3RCtMX2krJTNEKyU1Q2ZyYWMlN0IxJTdEJTdCTiU3RCU1Q3N1bV8lN0JpJTdELSU1QnlfaSU1Q2Nkb3QrbG9nJTI4cF9pJTI5KyUyQislMjgxLXlfaSUyOSU1Q2Nkb3QrbG9nJTI4MS1wX2klMjklNUQrJTVDJTVD.png)
其中:
-
—— 表示樣本
的label,正類為
,負類為 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0w.png)
-
—— 表示樣本
預測為正類的概率
(2) 多分類
多分類的情況實際上就是對二分類的擴展:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MKyUzRCslNUNmcmFjJTdCMSU3RCU3Qk4lN0QlNUNzdW1fJTdCaSU3RCtMX2krJTNEKy0rJTVDZnJhYyU3QjElN0QlN0JOJTdEJTVDc3VtXyU3QmklN0QrJTVDc3VtXyU3QmMlM0QxJTdEJTVFTXlfJTdCaWMlN0QlNUNsb2clMjhwXyU3QmljJTdEJTI5KyU1QyU1Qw==.png)
其中:
-
——類別的數量
-
——符號函數(
或
),如果樣本
的真實類別等於
取
,否則取 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0w.png)
-
——觀測樣本
屬於類別
的預測概率
現在我們利用這個表達式計算上面例子中的損失函數值:
模型1:![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrKysrJTVDdGV4dCU3QnNhbXBsZSsxK2xvc3MlN0QrJTNEKy0rJTI4MCU1Q3RpbWVzK2xvZzAuMyslMkIrMCU1Q3RpbWVzK2xvZzAuMyslMkIrMSU1Q3RpbWVzK2xvZzAuNCUyOSslM0QrMC45MSslNUMlNUMrKysrJTVDdGV4dCU3QnNhbXBsZSsyK2xvc3MlN0QrJTNEKy0rJTI4MCU1Q3RpbWVzK2xvZzAuMyslMkIrMSU1Q3RpbWVzK2xvZzAuNCslMkIrMCU1Q3RpbWVzK2xvZzAuMyUyOSslM0QrMC45MSslNUMlNUMrKysrJTVDdGV4dCU3QnNhbXBsZSszK2xvc3MlN0QrJTNEKy0rJTI4MSU1Q3RpbWVzK2xvZzAuMSslMkIrMCU1Q3RpbWVzK2xvZzAuMislMkIrMCU1Q3RpbWVzK2xvZzAuNyUyOSslM0QrMi4zMCslNUMlNUMrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUMlNUM=.png)
對所有樣本的loss求平均:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MJTNEJTVDZnJhYyU3QjAuOTElMkIwLjkxJTJCMi4zJTdEJTdCMyU3RCUzRDEuMzcrJTVDJTVD.png)
模型2:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrKysrJTVDdGV4dCU3QnNhbXBsZSsxK2xvc3MlN0QrJTNEKy0rJTI4MCU1Q3RpbWVzK2xvZzAuMSslMkIrMCU1Q3RpbWVzK2xvZzAuMislMkIrMSU1Q3RpbWVzK2xvZzAuNyUyOSslM0QrMC4zNSslNUMlNUMrKysrJTVDdGV4dCU3QnNhbXBsZSsyK2xvc3MlN0QrJTNEKy0rJTI4MCU1Q3RpbWVzK2xvZzAuMSslMkIrMSU1Q3RpbWVzK2xvZzAuNyslMkIrMCU1Q3RpbWVzK2xvZzAuMiUyOSslM0QrMC4zNSslNUMlNUMrKysrJTVDdGV4dCU3QnNhbXBsZSszK2xvc3MlN0QrJTNEKy0rJTI4MSU1Q3RpbWVzK2xvZzAuMyslMkIrMCU1Q3RpbWVzK2xvZzAuNCslMkIrMCU1Q3RpbWVzK2xvZzAuNCUyOSslM0QrMS4yMCslNUMlNUMrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUMlNUM=.png)
對所有樣本的loss求平均:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MJTNEJTVDZnJhYyU3QjAuMzUlMkIwLjM1JTJCMS4yJTdEJTdCMyU3RCUzRDAuNjMrJTVDJTVD.png)
可以發現,交叉熵損失函數可以捕捉到模型1和模型2預測效果的差異。
2. 函數性質

可以看出,該函數是凸函數,求導時能夠得到全局最優值。
3. 學習過程
交叉熵損失函數經常用於分類問題中,特別是在神經網絡做分類問題時,也經常使用交叉熵作為損失函數,此外,由於交叉熵涉及到計算每個類別的概率,所以交叉熵幾乎每次都和sigmoid(或softmax)函數一起出現。
我們用神經網絡最后一層輸出的情況,來看一眼整個模型預測、獲得損失和學習的流程:
- 神經網絡最后一層得到每個類別的得分scores(也叫logits);
- 該得分經過sigmoid(或softmax)函數獲得概率輸出;
- 模型預測的類別概率輸出與真實類別的one hot形式進行交叉熵損失函數的計算。
學習任務分為二分類和多分類情況,我們分別討論這兩種情況的學習過程。
3.1 二分類情況
二分類交叉熵損失函數學習過程
如上圖所示,求導過程可分成三個子過程,即拆成三項偏導的乘積:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtMX2klN0QlN0IlNUNwYXJ0aWFsK3dfaSU3RCUzRCU1Q2ZyYWMlN0IxJTdEJTdCTiU3RCU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK0xfaSU3RCU3QiU1Q3BhcnRpYWwrd19pJTdEJTNEJTVDZnJhYyU3QjElN0QlN0JOJTdEJTVDZnJhYyU3QiU1Q3BhcnRpYWwrTF9pJTdEJTdCJTVDcGFydGlhbCtwX2klN0QlNUNjZG90KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3BfaSU3RCU3QiU1Q3BhcnRpYWwrc19pJTdEJTVDY2RvdCslNUNmcmFjJTdCJTVDcGFydGlhbCtzX2klN0QlN0IlNUNwYXJ0aWFsK3dfaSU3RCU1QyU1Qw==.png)
3.1.1 計算第一項: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtMX2klN0QlN0IlNUNwYXJ0aWFsK3BfaSU3RA==.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MX2krJTNEKy0lNUJ5X2klNUNjZG90K2xvZyUyOHBfaSUyOSslMkIrJTI4MS15X2klMjklNUNjZG90K2xvZyUyODEtcF9pJTI5JTVEKyU1QyU1Qw==.png)
-
表示樣本
預測為正類的概率
-
為符號函數,樣本
為正類時取
,否則取 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0w.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrTF9pJTdEJTdCJTVDcGFydGlhbCtwX2klN0QrJTI2JTNEJTVDZnJhYyU3QiU1Q3BhcnRpYWwrLSU1QnlfaSU1Q2Nkb3QrbG9nJTI4cF9pJTI5KyUyQislMjgxLXlfaSUyOSU1Q2Nkb3QrbG9nJTI4MS1wX2klMjklNUQlN0QlN0IlNUNwYXJ0aWFsK3BfaSU3RCU1QyU1QyslMjYlM0QrLSU1Q2ZyYWMlN0J5X2klN0QlN0JwX2klN0QtJTVCJTI4MS15X2klMjklNUNjZG90KyU1Q2ZyYWMlN0IxJTdEJTdCMS1wX2klN0QlNUNjZG90KyUyOC0xJTI5JTVEKyU1QyU1QysrJTI2JTNEKy0lNUNmcmFjJTdCeV9pJTdEJTdCcF9pJTdEJTJCJTVDZnJhYyU3QjEteV9pJTdEJTdCMS1wX2klN0QrJTVDJTVDKyU1Q2VuZCU3QmFsaWduZWQlN0QrJTVDJTVD.png)
3.1.2 計算第二項: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtwX2klN0QlN0IlNUNwYXJ0aWFsK3NfaSU3RCs=.png)
這一項要計算的是sigmoid函數對於score的導數,我們先回顧一下sigmoid函數和分數求導的公式:
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3BfaSU3RCU3QiU1Q3BhcnRpYWwrc19pJTdEKyUyNiUzRCslNUNmcmFjJTdCJTI4ZSU1RSU3QnNfaSU3RCUyOSUyNyU1Q2Nkb3QrJTI4MSUyQmUlNUUlN0JzX2klN0QlMjktZSU1RSU3QnNfaSU3RCU1Q2Nkb3QrJTI4MSUyQmUlNUUlN0JzX2klN0QlMjklMjclN0QlN0IlMjgxJTJCZSU1RSU3QnNfaSU3RCUyOSU1RTIlN0QrJTVDJTVDKyslMjYlM0QrJTVDZnJhYyU3QmUlNUUlN0JzX2klN0QlNUNjZG90KyUyODElMkJlJTVFJTdCc19pJTdEJTI5LWUlNUUlN0JzX2klN0QlNUNjZG90K2UlNUUlN0JzX2klN0QlN0QlN0IlMjgxJTJCZSU1RSU3QnNfaSU3RCUyOSU1RTIlN0QrJTVDJTVDKyslMjYlM0QrJTVDZnJhYyU3QmUlNUUlN0JzX2klN0QlN0QlN0IlMjgxJTJCZSU1RSU3QnNfaSU3RCUyOSU1RTIlN0QrJTVDJTVDKyslMjYlM0QrJTVDZnJhYyU3QmUlNUUlN0JzX2klN0QlN0QlN0IxJTJCZSU1RSU3QnNfaSU3RCU3RCU1Q2Nkb3QrJTVDZnJhYyU3QjElN0QlN0IxJTJCZSU1RSU3QnNfaSU3RCU3RCslNUMlNUMrKyUyNiUzRCslNUNzaWdtYSUyOHNfaSUyOSU1Q2Nkb3QrJTVCMS0lNUNzaWdtYSUyOHNfaSUyOSU1RCslNUMlNUMrJTVDZW5kJTdCYWxpZ25lZCU3RCslNUMlNUM=.png)
3.1.3 計算第三項: ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtzX2klN0QlN0IlNUNwYXJ0aWFsK3dfaSslNUMlNUMlN0Q=.png)
一般來說,scores是輸入的線性函數作用的結果,所以有:![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtzX2klN0QlN0IlNUNwYXJ0aWFsK3dfaSU3RCUzRHhfaSslNUMlNUM=.png)
3.1.4 計算結果 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjJTdCJTVDcGFydGlhbCtMX2klN0QlN0IlNUNwYXJ0aWFsK3dfaSU3RA==.png)
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduZWQlN0QrKyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK0xfaSU3RCU3QiU1Q3BhcnRpYWwrd19pJTdEKyUyNiUzRCslNUNmcmFjJTdCJTVDcGFydGlhbCtMX2klN0QlN0IlNUNwYXJ0aWFsK3BfaSU3RCU1Q2Nkb3QrJTVDZnJhYyU3QiU1Q3BhcnRpYWwrcF9pJTdEJTdCJTVDcGFydGlhbCtzX2klN0QlNUNjZG90KyU1Q2ZyYWMlN0IlNUNwYXJ0aWFsK3NfaSU3RCU3QiU1Q3BhcnRpYWwrd19pJTdEKyU1QyU1QysrJTI2JTNEKyU1Qi0lNUNmcmFjJTdCeV9pJTdEJTdCcF9pJTdEJTJCJTVDZnJhYyU3QjEteV9pJTdEJTdCMS1wX2klN0QlNUQrJTVDY2RvdCslNUNzaWdtYSUyOHNfaSUyOSU1Q2Nkb3QrJTVCMS0lNUNzaWdtYSUyOHNfaSUyOSU1RCU1Q2Nkb3QreF9pKyU1QyU1QysrJTI2JTNEKyU1Qi0lNUNmcmFjJTdCeV9pJTdEJTdCJTVDc2lnbWElMjhzX2klMjklN0QlMkIlNUNmcmFjJTdCMS15X2klN0QlN0IxLSU1Q3NpZ21hJTI4c19pJTI5JTdEJTVEKyU1Q2Nkb3QrJTVDc2lnbWElMjhzX2klMjklNUNjZG90KyU1QjEtJTVDc2lnbWElMjhzX2klMjklNUQlNUNjZG90K3hfaSslNUMlNUMrKyUyNiUzRCslNUItJTVDZnJhYyU3QnlfaSU3RCU3QiU1Q3NpZ21hJTI4c19pJTI5JTdEJTVDY2RvdCslNUNzaWdtYSUyOHNfaSUyOSU1Q2Nkb3QrJTI4MS0lNUNzaWdtYSUyOHNfaSUyOSUyOSUyQiU1Q2ZyYWMlN0IxLXlfaSU3RCU3QjEtJTVDc2lnbWElMjhzX2klMjklN0QlNUNjZG90KyU1Q3NpZ21hJTI4c19pJTI5JTVDY2RvdCslMjgxLSU1Q3NpZ21hJTI4c19pJTI5JTI5JTVEJTVDY2RvdCt4X2krJTVDJTVDKyslMjYlM0QrJTVCLXlfaSUyQnlfaSU1Q2Nkb3QrJTVDc2lnbWElMjhzX2klMjklMkIlNUNzaWdtYSUyOHNfaSUyOS15X2klNUNjZG90KyU1Q3NpZ21hJTI4c19pJTI5JTVEJTVDY2RvdCt4X2krJTVDJTVDKyslMjYlM0QrJTVCJTVDc2lnbWElMjhzX2klMjkteV9pJTVEJTVDY2RvdCt4X2krJTVDJTVDKyU1Q2VuZCU3QmFsaWduZWQlN0QrJTVDJTVD.png)
可以看到,我們得到了一個非常漂亮的結果,所以,使用交叉熵損失函數,不僅可以很好的衡量模型的效果,又可以很容易的的進行求導計算。
3.2 多分類情況
待整理
4. 優缺點
4.1 優點
在用梯度下降法做參數更新的時候,模型學習的速度取決於兩個值:一、學習率;二、偏導值。其中,學習率是我們需要設置的超參數,所以我們重點關注偏導值。從上面的式子中,我們發現,偏導值的大小取決於
和
,我們重點關注后者,后者的大小值反映了我們模型的錯誤程度,該值越大,說明模型效果越差,但是該值越大同時也會使得偏導值越大,從而模型學習速度更快。所以,使用邏輯函數得到概率,並結合交叉熵當損失函數時,在模型效果差的時候學習速度比較快,在模型效果好的時候學習速度變慢。
4.2 缺點
Deng [4]在2019年提出了ArcFace Loss,並在論文里說了Softmax Loss的兩個缺點:1、隨着分類數目的增大,分類層的線性變化矩陣參數也隨着增大;2、對於封閉集分類問題,學習到的特征是可分離的,但對於開放集人臉識別問題,所學特征卻沒有足夠的區分性。對於人臉識別問題,首先人臉數目(對應分類數目)是很多的,而且會不斷有新的人臉進來,不是一個封閉集分類問題。
另外,sigmoid(softmax)+cross-entropy loss 擅長於學習類間的信息,因為它采用了類間競爭機制,它只關心對於正確標簽預測概率的准確性,忽略了其他非正確標簽的差異,導致學習到的特征比較散。基於這個問題的優化有很多,比如對softmax進行改進,如L-Softmax、SM-Softmax、AM-Softmax等。
5. 參考
[1]. 博客 - 神經網絡的分類模型 LOSS 函數為什么要用 CROSS ENTROPY
[2]. 博客 - Softmax as a Neural Networks Activation Function
[3]. 博客 - A Gentle Introduction to Cross-Entropy Loss Function
[4]. Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.
這篇也不錯
from: https://zhuanlan.zhihu.com/p/104130889
假設給定輸入為x,label為y,其中y的取值為0或者1,是一個分類問題。我們要訓練一個最簡單的Logistic Regression來學習一個函數f(x)使得它能較好的擬合label,如下圖所示。

其中
,
。
也即,我們要學的函數
。目標為使a(x)與label y越逼近越好。用哪種Loss來衡量這個逼近呢?我們可以回憶下交叉熵Loss和均方差Loss定義是什么:
- 最小均方誤差,MSE(Mean Squared Error)Loss
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MXyU3Qm1zZSU3RCslM0QrJTVDZnJhYyU3QjElN0QlN0IyJTdEJTI4YSstK3klMjklNUUy.png)
- 交叉熵誤差CEE(Cross Entropy Error)Loss
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MXyU3QmNlZSU3RCslM0QrLSUyOHklMkFsbiUyOGElMjkrJTJCKyUyODEteSUyOSUyQWxuJTI4MS1hJTI5JTI5.png)
我們想衡量模型輸出a和label y的逼近程度,其實這兩個Loss都可以。但是為什么Logistic Regression采用的是交叉熵作為損失函數呢?看下這兩個Loss function對w的導數,也就是SGD梯度下降時,w的梯度。
- 最小均方差
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjKyU3QiU1Q3BhcnRpYWwrTF8lN0Jtc2UlN0QlN0QlN0IlNUNwYXJ0aWFsK3clN0QrJTNEKyU1Q2ZyYWMrJTdCJTVDcGFydGlhbCtMJTdEJTdCJTVDcGFydGlhbCthJTdEKyUyQSslNUNmcmFjKyU3QiU1Q3BhcnRpYWwrYSU3RCU3QiU1Q3BhcnRpYWwreiU3RCslMkErJTVDZnJhYyslN0IlNUNwYXJ0aWFsK3olN0QlN0IlNUNwYXJ0aWFsK3clN0QrJTNEKyUyOGEteSUyOSslMkErJTVDc2lnbWElNUUlN0IlMjclN0QlMjh6JTI5JTJBK3g=.png)
- 交叉熵
![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjKyU3QiU1Q3BhcnRpYWwrTF8lN0JjZWUlN0QlN0QlN0IlNUNwYXJ0aWFsK3clN0QrJTNEKyUyOC0lNUNmcmFjKyU3QnklN0QlN0JhJTdEKyUyQislNUNmcmFjKyU3QjEteSU3RCU3QjEtYSU3RCUyOSslMkErJTVDc2lnbWElNUUlN0IlMjclN0QlMjh6JTI5JTJBK3g=.png)
由於
,則:![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNmcmFjKyU3QiU1Q3BhcnRpYWwrTF8lN0JjZWUlN0QlN0QlN0IlNUNwYXJ0aWFsK3clN0QrJTNEKyUyOGF5LXklMkJhLWF5JTI5JTJBeCslM0QrJTI4YS15JTI5JTJBeA==.png)
sigmoid函數
如下圖所示,可知的導數sigmoid
在輸出接近 0 和 1 的時候是非常小的,故導致在使用最小均方差Loss時,模型參數w會學習的非常慢。而使用交叉熵Loss則沒有這個問題。為了更快的學習速度,分類問題一般采用交叉熵損失函數。

當label = 1,也即
,交叉熵損失函數 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD1MXyU3QmNlZSU3RCslM0QrLSUyOHklMkFsbiUyOGElMjkrJTJCKyUyODEteSUyOSUyQWxuJTI4MS1hJTI5JTI5KyUzRCstbG4lMjhhJTI5.png)
如圖所示,可知交叉熵損失函數的值域為 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUIwJTJDJTJCJTVDaW5mdHklMjk=.png)

