誤差分析
如果上一章中的分類器是一個真實的項目,則我們最好是要遵循機器學習項目步驟:探索數據、准備數據、嘗試多個模型、列出表現最好的幾個模型、使用GridSearchCV對超參數進行調優、盡可能實現自動化。現在,假設我們已經有了一個性能還不錯的模型,接下來我們要找一些辦法去優化、提升它。其中一個辦法是就分析這個模型產生的各種不同類型的誤差、差錯。
首先我們看一下混淆矩陣,我們需要先使用cross_val_predict() 做預測,然后調用confusion_matrix() 計算:
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3) conf_mx = confusion_matrix(y_train, y_train_pred) >array([[5576, 0, 21, 6, 9, 43, 37, 6, 224, 1], [ 0, 6398, 38, 23, 4, 44, 4, 8, 213, 10], [ 26, 27, 5242, 90, 71, 26, 62, 36, 371, 7], [ 24, 17, 117, 5220, 2, 208, 28, 40, 405, 70], [ 12, 14, 48, 10, 5192, 10, 36, 26, 330, 164], [ 28, 15, 33, 166, 55, 4437, 76, 14, 538, 59], [ 30, 14, 41, 2, 43, 95, 5560, 4, 128, 1], [ 21, 9, 52, 27, 51, 12, 3, 5693, 188, 209], [ 17, 63, 46, 90, 3, 125, 25, 10, 5429, 43], [ 23, 18, 31, 66, 116, 32, 1, 179, 377, 5106]])
可以看到有很多的數字,為了方便一般我們會將這種混淆矩陣以圖片的方式展示出來,使用Matplotlib 的matshow() 方法:
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
這個混淆矩陣看起來還不錯,因為圖片基本都在主對角線上,也就是說它們被正確地分類到的所屬的類別。其中第5個的顏色相較其他數字稍深,說明可能有兩種問題:
- 數字5的圖片在數據集中較少
- 模型在數字5上的表現不如在其他數字上好
事實上我們可以確認這兩種問題都存在。
下面我們將關注點放在誤差上。首先我們需要將混淆矩陣中的每個值均除以對應類別的總數,用來對比誤差率(之前的混淆矩陣中,全部是精確的錯誤數,並不容易進行觀察與判斷):
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx/row_sums
然后我們將主對角線填充0,僅保留誤差,最后畫出結果:
np.fill_diagonal(norm_conf_mx, 0) plt.matshow(conf_mx, cmap=plt.cm.gray) plt.show()
現在我們可以明顯地看到分類器產生的誤差。這里回顧一下,行代表的是實際類別,列代表的是預測類別。可以明顯地看到第8列非常亮,它告訴我們的是:很多圖片被錯誤地分類成了數字8。然而,第8行卻並不差,說明:數字8一般都被正確地分類為了數字8。在圖中還可以看到混淆矩陣並不一定對稱。還可以看到數字3與數字5經常被混淆(行列均是),將數字3預測為數字5,並將數字5預測為數字3。
通過分析混淆矩陣,經常可以給我們提供一個更深層的視角觀察模型表現,並提供我們提升模型的思路。在上圖中,我們似乎需要將更多的精力花在減少錯誤預測的數字8(false 8)。例如,我們可以獲取更多的看起來像數字8但不是數字8的訓練數據,這樣可以讓分類器學習如何將它們與真正的數字8區分開來。或者也可以構造一些新的屬性幫助分類器,例如,寫一個算法,計算回環的數目(例如,8有兩個,6有一個,5沒有)。或者可以對圖片進行預處理(例如用sk-image,pillow,或OpenCV),讓一些模式更突出的顯示出來(例如回環)。
分析單獨的各個誤差也是一個很好的辦法,它可以告訴我們分類器做了什么,並且為什么分類失敗。不過這個過程會更難,並且更耗時。例如,我們可以畫出一些數字3與數字5:
def plot_digit(data): image = data.reshape(28, 28) plt.imshow(image, cmap = mpl.cm.binary, interpolation="nearest") plt.axis("off") def plot_digits(instances, images_per_row=10, **options): size = 28 images_per_row = min(len(instances), images_per_row) images = [instance.reshape(size,size) for instance in instances] n_rows = (len(instances) - 1) // images_per_row + 1 row_images = [] n_empty = n_rows * images_per_row - len(instances) images.append(np.zeros((size, size * n_empty))) for row in range(n_rows): rimages = images[row * images_per_row : (row + 1) * images_per_row] row_images.append(np.concatenate(rimages, axis=1)) image = np.concatenate(row_images, axis=0) plt.imshow(image, cmap = mpl.cm.binary, **options) plt.axis("off") cl_a, cl_b = 3, 5 X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)] X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)] X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)] X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)] plt.figure(figsize=(8,8)) plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5) plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5) plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5) plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5) plt.show()
左邊的兩個5×5 的圖展示的是被預測為“數字3“的圖,右邊的兩幅5×5的圖顯示的是被預測為”數字5“的圖。左下角與右上角的兩幅5×5的圖均是分類錯誤的圖片。從這些圖可以看出,分類器在分類某些圖片的時候,確實受到了手寫不規范的影響(例如左下角第1行第2列那個5,即使是人為分辨,也難以分辨為5還是3)。然而,除了少部分手寫的不清晰外,其他大部分的數字是能夠人為分辨的,所以光看圖很難理解為什么分類器在這些數字上分類錯誤。其實它的原因是由於我們使用了一個簡單的SGDClassifier,它是一個線性模型。它做的事是:給每個像素點分配一個權重,在它看到一張新圖片時,它僅會將所有帶權的像素點強度累加起來,最后會為每個類別生成一個分數。所以,由於數字3與數字5的像素點相差的不多,這個模型會很容易將它們混淆。
3與5的主要區別是連接上方橫線與下方灣溝的那條短線。如果我們在寫一個3時,把這條短線稍微靠了左邊,那這個分類器可能就會將它分類成5,反之亦然。換句話說,這個分類器對圖片的平移與旋轉非常敏感。所以其中一個減少3與5混淆不清的方法是預先處理圖片,並確保它們在正中間,且沒有旋轉。這個可能會對減少誤差有所幫助。