原理
對數損失, 即對數似然損失(Log-likelihood Loss), 也稱邏輯斯諦回歸損失(Logistic Loss)或交叉熵損失(cross-entropy Loss), 是在概率估計上定義的.它常用於(multi-nominal, 多項)邏輯斯諦回歸和神經網絡,以及一些期望極大算法的變體. 可用於評估分類器的概率輸出.
對數損失通過懲罰錯誤的分類,實現對分類器的准確度(Accuracy)的量化. 最小化對數損失基本等價於最大化分類器的准確度.為了計算對數損失, 分類器必須提供對輸入的所屬的每個類別的概率值, 不只是最可能的類別. 對數損失函數的計算公式如下:
其中, Y 為輸出變量, X為輸入變量, L 為損失函數. N為輸入樣本量, M為可能的類別數, yij 是一個二值指標, 表示類別 j 是否是輸入實例 xi 的真實類別. pij 為模型或分類器預測輸入實例 xi 屬於類別 j 的概率.
如果只有兩類 {0, 1}, 則對數損失函數的公式簡化為
這時, yi 為輸入實例 xi 的真實類別, pi 為預測輸入實例 xi 屬於類別 1 的概率. 對所有樣本的對數損失表示對每個樣本的對數損失的平均值, 對於完美的分類器, 對數損失為 0 .
Python 實現
采用自定義 logloss 函數和 scikit-learn 庫中 sklearn.metrics.log_loss 函數兩種方式實現對數損失, 如下所示:
#!/usr/bin/env python # -*- coding: utf8 -*- #author: klchang #date: 2018.6.23
# y_true: list, the true labels of input instances # y_pred: list, the probability when the predicted label of input instances equals to 1
def logloss(y_true, y_pred, eps=1e-15): import numpy as np # Prepare numpy array data
y_true = np.array(y_true) y_pred = np.array(y_pred) assert (len(y_true) and len(y_true) == len(y_pred)) # Clip y_pred between eps and 1-eps
p = np.clip(y_pred, eps, 1-eps) loss = np.sum(- y_true * np.log(p) - (1 - y_true) * np.log(1-p)) return loss / len(y_true) def unitest(): y_true = [0, 0, 1, 1] y_pred = [0.1, 0.2, 0.7, 0.99] print ("Use self-defined logloss() in binary classification, the result is {}".format(logloss(y_true, y_pred))) from sklearn.metrics import log_loss print ("Use log_loss() in scikit-learn, the result is {} ".format(log_loss(y_true, y_pred))) if __name__ == '__main__': unitest()
注: 在實現時, 加入參數 eps, 避免因預測概率輸出為 0 或 1 而導致的計算錯誤的情況; 對數損失函數的輸入參數 y_pred 為當預測實例屬於類 1 時的概率; 對數損失采用自然對數計算結果.
參考資料
1. Log Loss. http://wiki.fast.ai/index.php/Log_Loss
2. Making Sense of Logarithmic Loss. https://www.r-bloggers.com/making-sense-of-logarithmic-loss/
3. What is an intuitive explanation for the log loss function. https://www.quora.com/What-is-an-intuitive-explanation-for-the-log-loss-function
4. log-loss in scikit-learn documentation. http://scikit-learn.org/stable/modules/model_evaluation.html#log-loss
5. sklearn documentation-sklearn.metrics.log_loss. http://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html#sklearn.metrics.log_loss
6. 李航. 統計學習方法. 北京: 清華大學出版社. 2012