scikit-learn 多分類混淆矩陣


注:有些markdown語法沒渲染出來,可以簡書查看:scikit-learn 多分類混淆矩陣

前面

sklearn.metrics.multilabel_confusion_matrixscikit-learn 0.21 新增的一個函數。看名字可知道是用來計算多標簽的混淆矩陣的。不過也可以用它來計算多分類的混淆矩陣。MCM將多分類數據轉化為2分類問題,采用one-vs-rest策略,即某一類為正樣本,其余類別為負樣本。每一類都作為正樣本,計算混淆矩陣。按標簽的順序返回所有。
MCM 返回的每一個二分類混淆矩陣中,TN 在 [0, 0] ,FN 在 [1, 0] , TP 在[1,1], FP 在 [0, 1] , 即

TN FP
FN TP

官方例子

## 如果導入報錯,檢查一下 sk-learn version >= 0.21
>>> from sklearn.metrics import multilabel_confusion_matrix
>>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
>>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
>>> mcm = multilabel_confusion_matrix(y_true, y_pred,
...                             labels=["ant", "bird", "cat"])
>>> mcm
array([[[3, 1],
        [0, 2]],
       [[5, 0],
        [1, 0]],
       [[2, 1],
        [1, 2]]])

以第一個類別 ‘ant’ 為例,預測對的有2個,它的負樣本,'bird' 和 'cat' 預測對的有3個(‘bird’ 預測成 ‘cat’, 也算對的,因為它們是一類,都是負樣本。)負樣本預測成正樣本的有一個。

評估指標

每一類的TP, FP等可以提取通過:

>>> tp = mcm[:, 1, 1]
>>> tn = mcm[:, 0, 0]
>>> fn = mcm[:, 1, 0]
>>> tp, tn
(array([2, 0, 2], dtype=int64), array([3, 5, 2], dtype=int64))

這里有幾個常用的評估指標:

  1. 敏感性(sensitivity)也叫召回率(recall),也叫查全率。這個指標是看一下正樣本中預測對的占總正樣本的比例。也可以說成預測器對正樣本的敏感性,越大,說明預測器對正樣本越敏感。
    $$ sn=\frac{tp}{tp+fn} $$
  2. 特異性(specificity)這個和敏感性相反,敏感性算正樣本的,而特異性算的是負樣本的。換句話說,它是指負樣本的敏感性。畢竟你的預測器,不能僅僅是對正樣本敏感,負樣本,就隨意了。所以需要評估一下預測器對負樣本的敏感性。
    $$sp=\frac{tn}{tn+fp}$$
  3. 查准率(precision), 這是看你預測為正樣本中預測正確的占總的預測為正樣本的比例。
    $$precision=\frac{tp}{tp+fp}$$
  4. f1值,一般而言,查全率和查准率是不能同時都很大的。舉個例子:你現在有100個A和100個B,你用現在訓練好的模型去預測A,預測到有80個A。但是這其中75個是正確的A。也就是說查准率是$75/80=0.9375%$,查全率是$75/100=0.75$。你覺得查全率太低,你繼續改進模型。又進行了一次預測,這次預測到了95個A。其中預測正確的有85個,即查全率:$85/100=0.85$,增加了0.1,但是查准率:$85/95=0.895$下降了。你想查得越多,就更容易產生誤差。為了照顧兩頭,使得兩個指標都有不錯得值,就有了f1值:
    $$F1 = \frac{2 * (precision * recall)}{ (precision + recall)}$$

很容易通過代碼獲得多分類中每一類的評價指標值:

>>> sn = tp / (tp + fn) ## 其它同理
>>> sn
 array([1.        , 0.        , 0.66666667])x xz

利用one-vs-rest將多分類轉化為二分類問題時,往往會丟失一些信息。在負樣本中有多個類別,但不管在負樣本中否預測到其本身的標簽,只要不是預測為正樣本標簽就是正確的。所以不能很好的評價rest里的預測效果。想要更好的評價多分類,應考慮下宏平均或者微平均。

參考

sklearn.metrics.multilabel_confusion_matrix

原文:scikit-learn 多分類混淆矩陣


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM