為什么要用交叉熵來做損失函數:
在邏輯回歸問題中,常常使用MSE(Mean Squared Error)作為loss函數,此時:
這里的 就表示期望輸出,
表示原始的實際輸出(就是還沒有加softmax)。這里的m表示有m個樣本,loss為m個樣本的loss均值。MSE在邏輯回歸問題中比較好用,那么在分類問題中還是如此么?我們來看看Loss曲線。
將原始的實際輸出節點都經過softmax后拿出一個樣例來看,使用MSE的loss為的loss函數為:
其中 和
為常數,那么loss就可以簡化為
取c1=1,c2=2,繪制圖像:
這是一個非凸函數,只有當損失函數為凸函數時,梯度下降算法才能保證達到全局最優解。所以MSE在分類問題中,並不是一個好的loss函數。
如果利用交叉熵作為損失函數的話,那么:
還是一樣, 就表示期望輸出,
表示原始的實際輸出(就是還沒有加softmax),由於one-hot標簽的特殊性,一個1,剩下全是0,loss可以簡化為:
加入(softmax)得:
取C1=1,C2=2繪制曲線如下 :
相對MSE而言,曲線整體呈單調性,loss越大,梯度越大。便於梯度下降反向傳播,利於優化。所以一般針對分類問題采用交叉熵作為loss函數。
Pytorch中的CrossEntropyLoss()函數,計算公式如下:
ref:https://zhuanlan.zhihu.com/p/145533813
交叉熵損失函數相對MSE避免了梯度消失的一些推導:
ref:https://www.cnblogs.com/wanghui-garcia/p/10862733.html