學習機器學習默認導入庫:
%matplotlib notebook import numpy as np import matplotlib.pyplot as plt import pandas as pd import mglearn
分析過程:
我們構思了一項任務,要利用鳶尾花的物理測量數據來預測其品種。我們在構建模型時用到了由專家標注過 的測量數據集,專家已經給出了花的正確品種,因此這是一個監督學習問題。一共有三個 品種:setosa、versicolor 或 virginica,因此這是一個三分類問題。在分類問題中,可能的 品種被稱為類別(class),每朵花的品種被稱為它的標簽(label)。
鳶尾花(Iris)數據集包含兩個 NumPy 數組:一個包含數據,在 scikit-learn 中被稱為 X; 一個包含正確的輸出或預期輸出,被稱為 y。數組 X 是特征的二維數組,每個數據點對應 一行,每個特征對應一列。數組 y 是一維數組,里面包含一個類別標簽,對每個樣本都是 一個 0 到 2 之間的整數。
我們將數據集分成訓練集(training set)和測試集(test set),前者用於構建模型,后者用 於評估模型對前所未見的新數據的泛化能力。
我們選擇了 k 近鄰分類算法,根據新數據點在訓練集中距離最近的鄰居來進行預測。該算 法在 KNeighborsClassifier 類中實現,里面既包含構建模型的算法,也包含利用模型進行 預測的算法。我們將類實例化,並設定參數。然后調用 fit 方法來構建模型,傳入訓練數 據(X_train)和訓練輸出(y_trian)作為參數。我們用 score 方法來評估模型,該方法 計算的是模型精度。我們將 score 方法用於測試集數據和測試集標簽,得出模型的精度約 為 97%,也就是說,該模型在測試集上 97% 的預測都是正確的。