原理
在機器學習中, 混淆矩陣是一個誤差矩陣, 常用來可視化地評估監督學習算法的性能. 混淆矩陣大小為 (n_classes, n_classes) 的方陣, 其中 n_classes 表示類的數量. 這個矩陣的每一行表示真實類中的實例, 而每一列表示預測類中的實例 (Tensorflow 和 scikit-learn 采用的實現方式). 也可以是, 每一行表示預測類中的實例, 而每一列表示真實類中的實例 (Confusion matrix From Wikipedia 中的定義). 通過混淆矩陣, 可以很容易看出系統是否會弄混兩個類, 這也是混淆矩陣名字的由來.
混淆矩陣是一種特殊類型的列聯表(contingency table)或交叉制表(cross tabulation or crosstab). 其有兩維 (真實值 "actual" 和 預測值 "predicted" ), 這兩維都具有相同的類("classes")的集合. 在列聯表中, 每個維度和類的組合是一個變量. 列聯表以表的形式, 可視化地表示多個變量的頻率分布.
使用混淆矩陣( scikit-learn 和 Tensorflow)
下面先介紹在 scikit-learn 和 tensorflow 中計算混淆矩陣的 API (Application Programming Interface) 接口函數, 然后在一個示例中, 使用這兩個 API 函數.
scikit-learn 混淆矩陣函數 sklearn.metrics.confusion_matrix API 接口
skearn.metrics.confusion_matrix( y_true, # array, Gound true (correct) target values
y_pred, # array, Estimated targets as returned by a classifier
labels=None, # array, List of labels to index the matrix.
sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)
在 scikit-learn 中, 計算混淆矩陣用來評估分類的准確度.
按照定義, 混淆矩陣 C 中的元素 Ci,j 等於真實值為組 i , 而預測為組 j 的觀測數(the number of observations). 所以對於二分類任務, 預測結果中, 正確的負例數(true negatives, TN)為 C0,0; 錯誤的負例數(false negatives, FN)為 C1,0; 真實的正例數為 C1,1; 錯誤的正例數為 C0,1.
如果 labels 為 None, scikit-learn 會把在出現在 y_true 或 y_pred 中的所有值添加到標記列表 labels 中, 並排好序.
Tensorflow 混淆矩陣函數 tf.confusion_matrix API 接口
tf.confusion_matrix( labels, # 1-D Tensor of real labels for the classification task
predictions, # 1-D Tensor of predictions for a givenclassification
num_classes=None, # The possible number of labels the classification task can have
dtype=tf.int32, # Data type of the confusion matrix
name=None, # Scope name
weights=None, # An optional Tensor whose shape matches predictions
)
Tensorflow tf.confusion_matrix 中的 num_classes 參數的含義, 與 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 參數相近, 是與標記有關的參數, 表示類的總個數, 但沒有列出具體的標記值. 在 Tensorflow 中一般是以整數作為標記, 如果標記為字符串等非整數類型, 則需先轉為整數表示. 如果 num_classes 參數為 None, 則把 labels 和 predictions 中的最大值 + 1, 作為 num_classes 參數值.
tf.confusion_matrix 的 weights 參數和 sklearn.metrics.confusion_matrix 的 sample_weight 參數的含義相同, 都是對預測值進行加權, 在此基礎上, 計算混淆矩陣單元的值.
使用示例
#!/usr/bin/env python # -*- coding: utf8 -*-
"""
Author: klchang
Description:
A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
Date: 2018.9.8
"""
from __future__ import print_function import tensorflow as tf import sklearn.metrics y_true = [1, 2, 4] y_pred = [2, 2, 4] # Build graph with tf.confusion_matrix operation
sess = tf.InteractiveSession() op = tf.confusion_matrix(y_true, y_pred) op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3])) # Execute the graph
print ("confusion matrix in tensorflow: ") print ("1. default: \n", op.eval()) print ("2. customed: \n", sess.run(op2))
sess.close() # Use sklearn.metrics.confusion_matrix function
print ("\nconfusion matrix in scikit-learn: ") print ("1. default: \n", sklearn.metrics.confusion_matrix(y_true, y_pred)) print ("2. customed: \n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))
參考資料
1. Confusion matrix. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Confusion_matrix
2. Contingency table. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Contingency_table
3. Tensorflow API - tf.confusion_matrix. https://www.tensorflow.com/api_docs/python/tf/confusion_matrix
4. scikit-learn API - sklearn.metrics.confusion_matrix. http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
