分類問題(六)誤差分析


誤差分析

如果上一章中的分類器是一個真實的項目,則我們最好是要遵循機器學習項目步驟:探索數據、准備數據、嘗試多個模型、列出表現最好的幾個模型、使用GridSearchCV對超參數進行調優、盡可能實現自動化。現在,假設我們已經有了一個性能還不錯的模型,接下來我們要找一些辦法去優化、提升它。其中一個辦法是就分析這個模型產生的各種不同類型的誤差、差錯。

首先我們看一下混淆矩陣,我們需要先使用cross_val_predict() 做預測,然后調用confusion_matrix() 計算:

y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
>array([[5576,    0,   21,    6,    9,   43,   37,    6,  224,    1],
       [   0, 6398,   38,   23,    4,   44,    4,    8,  213,   10],
       [  26,   27, 5242,   90,   71,   26,   62,   36,  371,    7],
       [  24,   17,  117, 5220,    2,  208,   28,   40,  405,   70],
       [  12,   14,   48,   10, 5192,   10,   36,   26,  330,  164],
       [  28,   15,   33,  166,   55, 4437,   76,   14,  538,   59],
       [  30,   14,   41,    2,   43,   95, 5560,    4,  128,    1],
       [  21,    9,   52,   27,   51,   12,    3, 5693,  188,  209],
       [  17,   63,   46,   90,    3,  125,   25,   10, 5429,   43],
       [  23,   18,   31,   66,  116,   32,    1,  179,  377, 5106]])

 

可以看到有很多的數字,為了方便一般我們會將這種混淆矩陣以圖片的方式展示出來,使用Matplotlib 的matshow() 方法:

plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

這個混淆矩陣看起來還不錯,因為圖片基本都在主對角線上,也就是說它們被正確地分類到的所屬的類別。其中第5個的顏色相較其他數字稍深,說明可能有兩種問題:

  1. 數字5的圖片在數據集中較少
  2. 模型在數字5上的表現不如在其他數字上好

事實上我們可以確認這兩種問題都存在。

下面我們將關注點放在誤差上。首先我們需要將混淆矩陣中的每個值均除以對應類別的總數,用來對比誤差率(之前的混淆矩陣中,全部是精確的錯誤數,並不容易進行觀察與判斷):

row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx/row_sums

 

然后我們將主對角線填充0,僅保留誤差,最后畫出結果:

np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

 

現在我們可以明顯地看到分類器產生的誤差。這里回顧一下,行代表的是實際類別,列代表的是預測類別。可以明顯地看到第8列非常亮,它告訴我們的是:很多圖片被錯誤地分類成了數字8。然而,第8行卻並不差,說明:數字8一般都被正確地分類為了數字8。在圖中還可以看到混淆矩陣並不一定對稱。還可以看到數字3與數字5經常被混淆(行列均是),將數字3預測為數字5,並將數字5預測為數字3。

通過分析混淆矩陣,經常可以給我們提供一個更深層的視角觀察模型表現,並提供我們提升模型的思路。在上圖中,我們似乎需要將更多的精力花在減少錯誤預測的數字8(false 8)。例如,我們可以獲取更多的看起來像數字8但不是數字8的訓練數據,這樣可以讓分類器學習如何將它們與真正的數字8區分開來。或者也可以構造一些新的屬性幫助分類器,例如,寫一個算法,計算回環的數目(例如,8有兩個,6有一個,5沒有)。或者可以對圖片進行預處理(例如用sk-image,pillow,或OpenCV),讓一些模式更突出的顯示出來(例如回環)。

分析單獨的各個誤差也是一個很好的辦法,它可以告訴我們分類器做了什么,並且為什么分類失敗。不過這個過程會更難,並且更耗時。例如,我們可以畫出一些數字3與數字5:

def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap = mpl.cm.binary,
               interpolation="nearest")
    plt.axis("off")

def plot_digits(instances, images_per_row=10, **options):
    size = 28
    images_per_row = min(len(instances), images_per_row)
    images = [instance.reshape(size,size) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row + 1
    row_images = []
    n_empty = n_rows * images_per_row - len(instances)
    images.append(np.zeros((size, size * n_empty)))
    for row in range(n_rows):
        rimages = images[row * images_per_row : (row + 1) * images_per_row]
        row_images.append(np.concatenate(rimages, axis=1))
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap = mpl.cm.binary, **options)
    plt.axis("off")

cl_a, cl_b = 3, 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]

plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()

 

左邊的兩個5×5 的圖展示的是被預測為“數字3“的圖,右邊的兩幅5×5的圖顯示的是被預測為”數字5“的圖。左下角與右上角的兩幅5×5的圖均是分類錯誤的圖片。從這些圖可以看出,分類器在分類某些圖片的時候,確實受到了手寫不規范的影響(例如左下角第1行第2列那個5,即使是人為分辨,也難以分辨為5還是3)。然而,除了少部分手寫的不清晰外,其他大部分的數字是能夠人為分辨的,所以光看圖很難理解為什么分類器在這些數字上分類錯誤。其實它的原因是由於我們使用了一個簡單的SGDClassifier,它是一個線性模型。它做的事是:給每個像素點分配一個權重,在它看到一張新圖片時,它僅會將所有帶權的像素點強度累加起來,最后會為每個類別生成一個分數。所以,由於數字3與數字5的像素點相差的不多,這個模型會很容易將它們混淆。

3與5的主要區別是連接上方橫線與下方灣溝的那條短線。如果我們在寫一個3時,把這條短線稍微靠了左邊,那這個分類器可能就會將它分類成5,反之亦然。換句話說,這個分類器對圖片的平移與旋轉非常敏感。所以其中一個減少3與5混淆不清的方法是預先處理圖片,並確保它們在正中間,且沒有旋轉。這個可能會對減少誤差有所幫助。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM