python畫混淆矩陣(confusion matrix)


混淆矩陣(Confusion Matrix),是一種在深度學習中常用的輔助工具,可以讓你直觀地了解你的模型在哪一類樣本里面表現得不是很好。

如上圖,我們就可以看到,有一個樣本原本是0的,卻被預測成了1,還有一個,原本是2的,卻被預測成了0。

 

簡單介紹作用后,下面上代碼:

import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

導入需要的包,如果有一些包沒有,pip一下就可以了。

sns.set()
f,ax=plt.subplots()
y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
C2= confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(C2) #打印出來看看
sns.heatmap(C2,annot=True,ax=ax) #畫熱力圖

ax.set_title('confusion matrix') #標題
ax.set_xlabel('predict') #x軸
ax.set_ylabel('true') #y軸

下面就是結果:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM