基於sklearn的常用分類任務指標Python實現


基於sklearn的常用分類任務指標Python實現

一、摘要

分類任務常用指標包含混淆矩陣、每類分類精度、平均分類精度、總體分類精度、f1-score等。 Python的sklearn.metrics 模塊覆蓋了分類任務中大部分常用的驗證指標, 本文選擇其中幾種評價指標展示代碼片段,供讀者使用。 基於tensorflow-1.0與mnist數據集做demo展示並列舉實驗結果。 文末附有sklearn.metrics模塊的相關資料鏈接,方便高端玩家深入探索。

二、本文包含的評價指標

混淆矩陣(Confusion Matrix,CM) 
每類別分類精度 
每類別召回率 
平均分類精度(Average Accuracy,AA) 
總體分類精度(Overall Accuracy,OA)

三、功能代碼片段展示

代碼在tensorflow-1.0、Python3.5環境下通過測試,tf1.0版本API改動較大,1.0以下版本tensorflow可能不能通過測試,精力有限,其他環境尚未做測試。

 1 from sklearn import metrics
 2 import numpy as np
 3 #####
 4 # Do classification task, 
 5 # then get the ground truth and the predict label named y_true and y_pred
 6 classify_report = metrics.classification_report(y_true, y_pred)
 7 confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
 8 overall_accuracy = metrics.accuracy_score(y_true, y_pred)
 9 acc_for_each_class = metrics.precision_score(y_true, y_pred, average=None)
10 average_accuracy = np.mean(acc_for_each_class)
11 score = metrics.accuracy_score(y_true, y_pred)
12 print('classify_report : \n', classify_report)
13 print('confusion_matrix : \n', confusion_matrix)
14 print('acc_for_each_class : \n', acc_for_each_class)
15 print('average_accuracy: {0:f}'.format(average_accuracy))
16 print('overall_accuracy: {0:f}'.format(overall_accuracy))
17 print('score: {0:f}'.format(score))

 

四、實驗結果展示

本文基於tensorflow-1.0框架與mnist數據集,使用線性分類器與卷積神經網絡分類並使用上文提到的代碼片段展示分類性能。

分類性能結果直觀,排列清晰,便於二次使用。

1. 線性分類器分類報告:

 

2. 線性分類器混淆矩陣與其他分類指標展示:

 

3. 卷積神經網絡每層參數顯示:

 

4. 卷積神經網絡分類報告:

 

5. 卷積神經網絡混淆矩陣與其他分類指標展示:

 

 

五、代碼示例

使用類似如下的代碼片段可以直觀查看tensor相關內容

 1 print(some_tensor.op.name, ' ', some_tensor.get_shape().as_list()) 

代碼太長,這里就不粘貼了。代碼來源:https://github.com/JiJingYu/tensorflow-exercise/blob/master/mnist_test.py

 六、總結

不得不說sklearn是個全面的Python模塊,常用的機器學習方法以及評價准則都能從中找到函數與例程。同時,tensorflow作為google親兒子,發展勁頭勢不可擋。

sklearn 官網:http://scikit-learn.org/stable/index.html

一份高質量 sklearn tutorial:https://github.com/jakevdp/sklearn_tutorial

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM