多分類問題的交叉熵
在多分類問題中,損失函數(loss function)為交叉熵(cross entropy)損失函數。對於樣本點(x,y)來說,y是真實的標簽,在多分類問題中,其取值只可能為標簽集合labels. 我們假設有K個標簽值,且第i個樣本預測為第k個標簽值的概率為\(p_{i,k}\), 即\(p_{i,k} = \operatorname{Pr}(t_{i,k} = 1)\), 一共有N個樣本,則該數據集的損失函數為
一個例子
在Python的sklearn模塊中,提供了一個函數log_loss()來計算多分類問題的交叉熵。再根據我們在博客Sklearn中二分類問題的交叉熵計算對log_loss()函數的源代碼的分析,我們不難利用上面的計算公式用自己的方法來實現交叉熵的求值。
我們給出的例子如下:
y_true = ['1', '4', '5'] # 樣本的真實標簽
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],
[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],
[0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]
] # 樣本的預測概率
labels = ['0','1','2','3','4','5','6','7','8','9'] # 所有標簽
在這個例子中,一個有3個樣本,標簽為1,4,5,一共是10個標簽,y_pred是對每個樣本的所有標簽的預測值。
接下來我們將會用log_loss()函數和自己的方法分別來實現這個例子的交叉熵的計算,完整的Python代碼如下:
from sklearn.metrics import log_loss
from sklearn.preprocessing import LabelBinarizer
from math import log
y_true = ['1', '4', '5'] # 樣本的真實標簽
y_pred = [[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],
[0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],
[0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]
] # 樣本的預測概率
labels = ['0','1','2','3','4','5','6','7','8','9'] # 所有標簽
# 利用sklearn中的log_loss()函數計算交叉熵
sk_log_loss = log_loss(y_true, y_pred, labels=labels)
print("Loss by sklearn is:%s." %sk_log_loss)
# 利用公式實現交叉熵
# 交叉熵的計算公式網址為:
# http://scikit-learn.org/stable/modules/model_evaluation.html#log-loss
# 對樣本的真實標簽進行標簽二值化
lb = LabelBinarizer()
lb.fit(labels)
transformed_labels = lb.transform(y_true)
# print(transformed_labels)
N = len(y_true) # 樣本個數
K = len(labels) # 標簽個數
eps = 1e-15 # 預測概率的控制值
Loss = 0 # 損失值初始化
for i in range(N):
for k in range(K):
# 控制預測概率在[eps, 1-eps]內,避免求對數時出現問題
if y_pred[i][k] < eps:
y_pred[i][k] = eps
if y_pred[i][k] > 1-eps:
y_pred[i][k] = 1-eps
# 多分類問題的交叉熵計算公式
Loss -= transformed_labels[i][k]*log(y_pred[i][k])
Loss /= N
print("Loss by equation is:%s." % Loss)
輸出的結果如下:
Loss by sklearn is:1.16885263244.
Loss by equation is:1.16885263244.
這說明我們能夠用公式來自己實現交叉熵的計算了,是不是很神奇呢?
多分類問題的交叉熵計算是建立在二分類問題的交叉熵計算的基礎上,有了我們對log_loss()函數的源代碼的研究,那就用自己的方法來實現多(二)分類問題的交叉熵計算就不是問題了~~
本次分享到此結束,歡迎大家交流~~
注意:本人現已開通兩個微信公眾號: 因為Python(微信號為:python_math)以及輕松學會Python爬蟲(微信號為:easy_web_scrape), 歡迎大家關注哦~~