分類問題(一)MINST數據集與二元分類器


分類問題

在機器學習中,主要有兩大類問題,分別是分類和回歸。下面我們先主講分類問題。

 

MINST

這里我們會用MINST數據集,也就是眾所周知的手寫數字集,機器學習中的 Hello World。sk-learn 提供了用於直接下載此數據集的方法:

from sklearn.datasets import fetch_openml

minst = fetch_openml('mnist_784', version=1)
minst.keys()
>dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])

像這種sk-learn 下載的數據集,一般都有相似的字典結構,包括:

  • DESCR:描述數據集
  • data:包含一個數組,每行是一條數據,每列是一個特征
  • target:包含一個數組,為label值

我們看一下這些數組:

X,y = minst['data'],minst['target']
X.shape, y.shape
>((70000, 784), (70000,))

可以看到一共有 70000 張圖片,每張圖片包含784個特征。這是因為每張圖包含28×28像素點,每個特征代表的是此像素點強度,取值范圍從0(白)到255(黑)。我們先看一下其中一條數據。首先獲取一條數據的特征向量,然后reshape到一個28×28 的數組,最后用matplotlib 的imshow() 方法顯示即可:

import matplotlib as mpl
import matplotlib.pyplot as plt

some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)

plt.imshow(some_digit_image, cmap = mpl.cm.binary, interpolation="nearest")
plt.axis("off")
plt.show()

從圖片來看,這個應該是數字5,我們可以通過label 進行驗證:

y[0]
>'5'

可以看到這個label的數值是 string,我們需要將它們轉換成int:

import numpy as np

y = y.astype(np.uint8)
>array([5, 0, 4, ..., 4, 5, 6], dtype=uint8)

現在,我們初步了解了數據集。在訓練之前,必須要將數據集分為訓練集與測試集。這個MINST數據集已經做好了划分,前60000 為訓練接,后10000為測試集,直接取用即可:

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

這個訓練集已經做過了shuffle,基本可以確保k-折交叉驗證的各個集合基本相似(例如不會出現某個折中缺失一些數字)。另一方面,有些學習算法對於訓練數據的順序比較敏感,所以對數據集進行shuffle的好處是避免數據的順序對訓練造成的影響。

 

訓練二元分類器

我們先簡化此問題,僅讓我們的模型判斷一個數字,例如5。這樣的分類器稱為二元分類器,僅能將數據分為兩個類別:數字5和非數字5。下面我們為這類分類器創建label:

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

現在我們選擇一個分類器並進行訓練,可以先從一個隨機梯度下降(Stochastic Gradient Descent,SGD) 分類器開始,使用sk-learn的SGDClassifer 類。這個分類器的優點是:能夠高效地處理非常大的數據集。因為它每次均僅處理一條數據(也正因如此,SGD非常適合online learning 場景)。下面創建一個SGDClassifer 並在整個訓練集上進行訓練:

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

SGDClassifier在訓練時會隨機選擇數據,如果要復現結果的話,則需要手動設置random_state 參數。現在我們可以使用已訓練好的模型進行預測一個手寫數字是否是5:

sgd_clf.predict([X_test[0], X_test[1], X_test[2]])
>array([False, False, False])

看起來結果還不錯,我們稍后評估一下這個模型的性能。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM