sklearn之分類模型混淆矩陣和分類報告


'''
    1.分類模型之混淆矩陣:
            每一行和每一列分別對應樣本輸出中的每一個類別,行表示實際類別,列表示預測類別。
                        A類別    B類別    C類別
                A類別    5        0        0
                B類別    0        6        0
                C類別    0        0        7
            上述矩陣即為理想的混淆矩陣。不理想的混淆矩陣如下:
                        A類別    B類別    C類別
                A類別    3        1        1
                B類別    0        4        2
                C類別    0        0        7
            查准率 = 主對角線上的值 / 該值所在列的和
            召回率 = 主對角線上的值 / 該值所在行的和

    獲取模型分類結果的混淆矩陣的相關API:
            import sklearn.metrics as sm
            sm.confusion_matrix(實際輸出, 預測輸出)->混淆矩陣

    2.分類模型之分類報告:
                sklearn.metrics提供了分類報告相關API,不僅可以得到混淆矩陣,還可以得到交叉驗證查准率、召回率、f1得分的結果,
                可以方便的分析出哪些樣本是異常樣本。

            # 獲取分類報告
            cr = sm.classification_report(實際輸出, 預測輸出)


'''

import numpy as np
import matplotlib.pyplot as mp
import sklearn.naive_bayes as nb
import sklearn.model_selection as ms
import sklearn.metrics as sm

data = np.loadtxt('./ml_data/multiple1.txt', delimiter=',', unpack=False, dtype='f8')
print(data.shape)
x = np.array(data[:, :-1])
y = np.array(data[:, -1])

# 訓練集和測試集的划分    使用訓練集訓練 再使用測試集測試,並繪制測試集樣本圖像
train_x, test_x, train_y, test_y = ms.train_test_split(x, y, test_size=0.25, random_state=7)

# 針對訓練集,做5次交叉驗證,若得分還不錯再訓練模型
model = nb.GaussianNB()
# 精確度
score = ms.cross_val_score(model, train_x, train_y, cv=5, scoring='accuracy')
print('accuracy score=', score)
print('accuracy mean=', score.mean())

# 查准率
score = ms.cross_val_score(model, train_x, train_y, cv=5, scoring='precision_weighted')
print('precision_weighted score=', score)
print('precision_weighted mean=', score.mean())

# 召回率
score = ms.cross_val_score(model, train_x, train_y, cv=5, scoring='recall_weighted')
print('recall_weighted score=', score)
print('recall_weighted mean=', score.mean())

# f1得分
score = ms.cross_val_score(model, train_x, train_y, cv=5, scoring='f1_weighted')
print('f1_weighted score=', score)
print('f1_weighted mean=', score.mean())

# 訓練NB模型,完成分類業務
model.fit(train_x, train_y)
pred_test_y = model.predict(test_x)
# 得到預測輸出,可以與真實輸出作比較,計算預測的精准度(預測正確的樣本數/總測試樣本數)
ac = (test_y == pred_test_y).sum() / test_y.size
print('預測精准度 ac=', ac)

# 獲取混淆矩陣
m = sm.confusion_matrix(test_y, pred_test_y)
print('混淆矩陣為:', m, sep='\n')

# 獲取分類報告
r = sm.classification_report(test_y, pred_test_y)
print('分類報告為:', r, sep='\n')

# 繪制分類邊界線
l, r = x[:, 0].min() - 1, x[:, 0].max() + 1
b, t = x[:, 1].min() - 1, x[:, 1].max() + 1
n = 500
grid_x, grid_y = np.meshgrid(np.linspace(l, r, n), np.linspace(b, t, n))
bg_x = np.column_stack((grid_x.ravel(), grid_y.ravel()))
bg_y = model.predict(bg_x)
grid_z = bg_y.reshape(grid_x.shape)

# 畫圖
mp.figure('NB Classification', facecolor='lightgray')
mp.title('NB Classification', fontsize=16)
mp.xlabel('X', fontsize=14)
mp.ylabel('Y', fontsize=14)
mp.tick_params(labelsize=10)
mp.pcolormesh(grid_x, grid_y, grid_z, cmap='gray')
mp.scatter(test_x[:, 0], test_x[:, 1], s=80, c=test_y, cmap='jet', label='Samples')

mp.legend()
mp.show()

# 畫出混淆矩陣
mp.figure('Confusion Matrix')
mp.xticks([])
mp.yticks([])
mp.imshow(m, cmap='gray')
mp.show()



輸出結果:
(400, 3)
accuracy score= [1.         1.         1.         1.         0.98305085]
accuracy mean= 0.9966101694915255
precision_weighted score= [1.         1.         1.         1.         0.98411017]
precision_weighted mean= 0.996822033898305
recall_weighted score= [1.         1.         1.         1.         0.98305085]
recall_weighted mean= 0.9966101694915255
f1_weighted score= [1.         1.         1.         1.         0.98303199]
f1_weighted mean= 0.9966063988235516
預測精准度 ac= 0.99
混淆矩陣為:
[[22  0  0  0]
 [ 0 27  1  0]
 [ 0  0 25  0]
 [ 0  0  0 25]]
分類報告為:
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00        22
         1.0       1.00      0.96      0.98        28
         2.0       0.96      1.00      0.98        25
         3.0       1.00      1.00      1.00        25

    accuracy                           0.99       100
   macro avg       0.99      0.99      0.99       100
weighted avg       0.99      0.99      0.99       100

  

 

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM