分類問題
在機器學習中,主要有兩大類問題,分別是分類和回歸。下面我們先主講分類問題。
MINST
這里我們會用MINST數據集,也就是眾所周知的手寫數字集,機器學習中的 Hello World。sk-learn 提供了用於直接下載此數據集的方法:
from sklearn.datasets import fetch_openml minst = fetch_openml('mnist_784', version=1) minst.keys() >dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])
像這種sk-learn 下載的數據集,一般都有相似的字典結構,包括:
- DESCR:描述數據集
- data:包含一個數組,每行是一條數據,每列是一個特征
- target:包含一個數組,為label值
我們看一下這些數組:
X,y = minst['data'],minst['target'] X.shape, y.shape >((70000, 784), (70000,))
可以看到一共有 70000 張圖片,每張圖片包含784個特征。這是因為每張圖包含28×28像素點,每個特征代表的是此像素點強度,取值范圍從0(白)到255(黑)。我們先看一下其中一條數據。首先獲取一條數據的特征向量,然后reshape到一個28×28 的數組,最后用matplotlib 的imshow() 方法顯示即可:
import matplotlib as mpl import matplotlib.pyplot as plt some_digit = X[0] some_digit_image = some_digit.reshape(28, 28) plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation="nearest") plt.axis("off") plt.show()
從圖片來看,這個應該是數字5,我們可以通過label 進行驗證:
y[0] >'5'
可以看到這個label的數值是 string,我們需要將它們轉換成int:
import numpy as np y = y.astype(np.uint8) >array([5, 0, 4, ..., 4, 5, 6], dtype=uint8)
現在,我們初步了解了數據集。在訓練之前,必須要將數據集分為訓練集與測試集。這個MINST數據集已經做好了划分,前60000 為訓練接,后10000為測試集,直接取用即可:
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
這個訓練集已經做過了shuffle,基本可以確保k-折交叉驗證的各個集合基本相似(例如不會出現某個折中缺失一些數字)。另一方面,有些學習算法對於訓練數據的順序比較敏感,所以對數據集進行shuffle的好處是避免數據的順序對訓練造成的影響。
訓練二元分類器
我們先簡化此問題,僅讓我們的模型判斷一個數字,例如5。這樣的分類器稱為二元分類器,僅能將數據分為兩個類別:數字5和非數字5。下面我們為這類分類器創建label:
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)
現在我們選擇一個分類器並進行訓練,可以先從一個隨機梯度下降(Stochastic Gradient Descent,SGD) 分類器開始,使用sk-learn的SGDClassifer 類。這個分類器的優點是:能夠高效地處理非常大的數據集。因為它每次均僅處理一條數據(也正因如此,SGD非常適合online learning 場景)。下面創建一個SGDClassifer 並在整個訓練集上進行訓練:
from sklearn.linear_model import SGDClassifier sgd_clf = SGDClassifier(random_state=42) sgd_clf.fit(X_train, y_train_5)
SGDClassifier在訓練時會隨機選擇數據,如果要復現結果的話,則需要手動設置random_state 參數。現在我們可以使用已訓練好的模型進行預測一個手寫數字是否是5:
sgd_clf.predict([X_test[0], X_test[1], X_test[2]])
>array([False, False, False])
看起來結果還不錯,我們稍后評估一下這個模型的性能。