Pytorch里的CrossEntropyLoss詳解


在使用Pytorch時經常碰見這些函數cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我頭大,所以整理本文以備日后查閱。

首先要知道上面提到的這些函數一部分是來自於torch.nn,而另一部分則來自於torch.nn.functional(常縮寫為F)。二者函數的區別可參見 知乎:torch.nn和funtional函數區別是什么?

下面是對與cross entropy有關的函數做的總結:

torch.nn torch.nn.functional (F)
CrossEntropyLoss cross_entropy
LogSoftmax log_softmax
NLLLoss nll_loss

下面將主要介紹torch.nn.functional中的函數為主,torch.nn中對應的函數其實就是對F里的函數進行包裝以便管理變量等操作。

在介紹cross_entropy之前先介紹兩個基本函數:

log_softmax

這個很好理解,其實就是logsoftmax合並在一起執行。

nll_loss

該函數的全程是negative log likelihood loss,函數表達式為

\[f(x,class)=-x[class] \]

例如假設\(x=[1,2,3], class=2\),那額\(f(x,class)=-x[2]=-3\)

cross_entropy

交叉熵的計算公式為:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right) \]

其中\(p\)表示真實值,在這個公式中是one-hot形式;\(q\)是預測值,在這里假設已經是經過softmax后的結果了。

仔細觀察可以知道,因為\(p\)的元素不是0就是1,而且又是乘法,所以很自然地我們如果知道1所對應的index,那么就不用做其他無意義的運算了。所以在pytorch代碼中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真實類別)可變形為:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)=-log \, q_m \]

仔細看看,是不是就是等同於log_softmaxnll_loss兩個步驟。

所以Pytorch中的F.cross_entropy會自動調用上面介紹的log_softmaxnll_loss來計算交叉熵,其計算方式如下:

\[\operatorname{loss}(x, \text {class})=-\log \left(\frac{\exp (x[\operatorname{class}])}{\sum_{j} \exp (x[j])}\right) \]

代碼示例

>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randint(5, (3,), dtype=torch.int64)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()




微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com

2019-2-19




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM