3.1 MNIST
本章介紹分類,使用MNIST數據集。該數據集包含七萬個手寫數字圖片。使用Scikit-Learn函數即可下載該數據集:
>>> from sklearn.datasets import fetch_mldata
>>> mnist = fetch_mldata('MNIST original')
>>> X, y = mnist["data"], mnist["target"]
>>> X.shape
(70000, 784)
>>> y.shape
(70000,)
70000張圖片,每張圖片有784個特征,代表28*28個像素點。每個像素點取值從0(白)到255(黑)。並且前60000張是訓練集,后10000張是測試集。
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
訓練集是按照數字的順序進行排序的,我們需要將順序打亂,這可以保證交叉驗證的k個部分是一致的(我們不希望某一部分缺少一些數字)。此外,一些算法對訓練集的順序是敏感的,在一行出現很多相似樣本時會表現很差。打算訓練集就是為了防止這一情況發生。有時候打亂順序是不明智的——例如,處理的是時序數據(time series data,比如股價、天氣),這將在后面章節討論。
import numpy as np # 打亂訓練集數據順序 shuffle_index = np.random.permutation(60000) X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
3.2 訓練二元分類器(Training a Binary Classifier)
首先將問題簡化,訓練一個二元分類器。比如只判斷圖像是5或者不是5。目標向量可通過如下代碼創建:
y_train_5 = (y_train == 5) # True for all 5s, False for all other digits. y_test_5 = (y_test == 5)
作者選擇了隨機梯度下降(Stochastic Gradient Descent,SGD。梯度下降可參考:梯度下降求解線性回歸)分類器,Scikit-Learn’s SGDClassifier。
3.3 性能評估(Performance Measures)
3.3.1 交叉驗證計算准確率(Measuring Accuracy Using Cross-Validation)
>>> from sklearn.model_selection import cross_val_score >>> cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy") array([ 0.9502 , 0.96565, 0.96495])
得了95%以上的正確率,這似乎很不錯了。事實上,我們可以定義一個很弱智的分類器,該分類器把所有圖像都識別為不是5,該分類器也能有90%的正確率,因為5的圖像只占總數的10%。這就很尷尬了。
因此,對於分類問題來說,准確率通常不是最好的衡量指標,特別是處理傾斜數據集時(skewed datasets,例如一些類別的頻率明顯高於其它類別)。
3.3.2 混淆矩陣(Confusion Matrix)
>>> from sklearn.model_selection import cross_val_predict >>> y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3) >>> from sklearn.metrics import confusion_matrix >>> confusion_matrix(y_train_5, y_train_pred) array([[53272, 1307], [ 1077, 4344]])

每行代表真實類別,每列代表預測類別。第一行是真實值為非5的圖像(負類別,the negative class):53,272個樣本正確分類為非5(這被稱作true negatives,TN),其余的1,307個被錯誤分類為5(false positives,FP)。第二行是真實值為5的圖像:1,077個圖片被錯誤分類為非5(false negatives),剩下的4,344個被正確分類為5(true positives)。
定義精度(precision)和召回率(recall):
\begin{align*}
precision &= \frac{TP}{TP + FP} \\
recall &= \frac{TP}{TP + FN} \\
\end{align*}
3.2.3 精度和召回率(Precision and Recall)
>>> from sklearn.metrics import precision_score, recall_score >>> precision_score(y_train_5, y_pred) # == 4344 / (4344 + 1307) 0.76871350203503808 >>> recall_score(y_train_5, y_train_pred) # == 4344 / (4344 + 1077) 0.79136690647482011
現在可以看出,我們的分類是表現的並不好,盡管准確率(accuracy)是95%以上。當分類器認為一個圖像是5時,這只有不到77%的情況下是正確的。此外,只檢測大了79%的5。
可以將精度和召回率組合成一個被稱為$F_1$值的指標,這在比較兩個分類器時很方便。$F_1$值是精度和召回率的調和平均數(harmonic mean)。普通的平均數處理所有值都是均等的,調和平均數給予小值更高的權重。只有在精度和召回率都比較高的情況下,才會得到比較高的$F_1$值。
\begin{align*}
F_1 = \frac{2}{\frac{1}{precision} + \frac{1}{recall}} = 2 \times \frac{precision \times recall}{precision + recall} = \frac{TP}{TP + \frac{FN + FP}{2}}
\end{align*}
>>> from sklearn.metrics import f1_score >>> f1_score(y_train_5, y_pred) 0.78468208092485547
精度和召回率相近的分類器,傾向於得到較高的$F_1$值。但有時候我們更關心精度,有時候真正看重的是召回率。
例如,訓練一個視頻分類器,檢測出對兒童安全的視頻,這就需要寧缺(低召回率)毋濫(高精度)了。
再比如,你的分類器時檢測扒手的,為了一個壞人都不放過(高召回率),即使精度低一些也可以接受。
不幸的是,魚和熊掌不可兼得:增大召回率造成精度減小,反之亦然。這被稱為精度/召回率權衡(precision/recall tradeoff)。
3.2.4 精度/召回率權衡(precision/recall tradeoff)
首先說明一下SGDClassifier是怎么做分類決策的。對於每一個實例,它都會通過決策函數計算一個分支,如果該分值高於閾值, 就預測該實例為正樣本,反之預測為負樣本。

圖3-3.決策閾值和精度/召回率權衡
雖然Scikit-Learn並不允許直接修改閾值,但可以獲取用於預測的決策分值(decision scores)。
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
plt.xlabel("Threshold")
plt.legend(loc="upper left")
plt.ylim([0, 1])
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()

不同閾值下的精度和召回率
另外一個權衡精度和召回率的方式是直接畫出二者圖像:

可以看出,在80%召回率附近,精度開始快速下降。可以在這一下降之前對精度和召回率做一權衡,比如選擇60%的召回率。當然,這取決於具體的項目。
如果有人說:讓我們達到99%的精度。你應該問,基於什么樣的召回率?
如果一個分類器召回率特別低,即使它的精度很高,那也沒什么用。
3.2.5 ROC
ROC(receiver operating characteristic)曲線是另一個二分類器的常用工具。它和精度/召回率曲線類似。不同之處在於,ROC曲線畫出的是不同FPR(false positive rate)下的TPR(true positive rate,這是召回率的別名)。
from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label=label)
plt.plot([0, 1], [0, 1], 'k--')
plt.axis([0, 1, 0, 1])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plot_roc_curve(fpr, tpr)
plt.show()

圖3-6.ROC曲線
這也需要進行權衡:召回率(TPR)越高,分類器就會產生越多的錯誤正樣本(FPR)。
分類器好壞的一個度量方式是AUC(area under the curve)。一個完美的分類器,ROC AUC等於1。而一個完全隨機的分類器,ROC AUC等於0.5。
由於ROC曲線和精度/召回率(precision/recall,PR)曲線是如此的相似, 或許存在困惑該如何選取。一般來說,如果正樣本是稀少的,或者相較於錯誤的負樣本,你更關心錯誤的正樣本,那就應該選擇PR曲線。反之,選擇ROC曲線。例如,觀察一下先前的ROC曲線(包括ROC AUC分值),那可能覺得分類器已經相當好了。但這主要是因為負樣本(非5)明顯多於正樣本(5)。與之相反,PR曲線顯示出我們的分類器明顯還有提升的空間(曲線可以更靠近右上角)。
3.3 多標簽分類(Multilabel Classification)
3.4 Multioutput Classification
