python画混淆矩阵(confusion matrix)


混淆矩阵(Confusion Matrix),是一种在深度学习中常用的辅助工具,可以让你直观地了解你的模型在哪一类样本里面表现得不是很好。

如上图,我们就可以看到,有一个样本原本是0的,却被预测成了1,还有一个,原本是2的,却被预测成了0。

 

简单介绍作用后,下面上代码:

import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

导入需要的包,如果有一些包没有,pip一下就可以了。

sns.set()
f,ax=plt.subplots()
y_true = [0,0,1,2,1,2,0,2,2,0,1,1]
y_pred = [1,0,1,2,1,0,0,2,2,0,1,1]
C2= confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(C2) #打印出来看看
sns.heatmap(C2,annot=True,ax=ax) #画热力图

ax.set_title('confusion matrix') #标题
ax.set_xlabel('predict') #x轴
ax.set_ylabel('true') #y轴

下面就是结果:

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM