鳶尾花數據分類,通過Python實現KNN分類算法。
項目來源:https://aistudio.baidu.com/aistudio/projectdetail/1988428
數據集來源:鳶尾花數據集https://aistudio.baidu.com/aistudio/datasetdetail/91206
1 import numpy as np 2 import pandas as pd 3 import matplotlib as mpl 4 import matplotlib.pyplot as plt 5 6 # 讀取鳶尾花數據集,header參數來指定標題的行。默認為0。如果沒有標題,則使用None 7 # 四個特征分別為花萼長度sepal length,花萼寬度sepal width,花瓣長度petal length,花瓣寬度petal width。鳶尾花的種類,共有3種,分別為山鳶尾Iris Setosa、雜色鳶尾Iris Versicolour、維吉尼亞鳶尾Iris Virginica。 8 data = pd.read_csv('./data/data91206/iris.csv', header=0) 9 # 顯示全部數據 10 # data 11 12 # 顯示前n行的數據,默認n的值為5 13 # data.head() 14 15 # 顯示末尾的n行記錄,默認n的值為5 16 # data.tail() 17 18 # 隨機抽取樣本,默認抽取一條,我們可以通過修改參數來指定抽取樣本的數量 19 data.sample(10) 20 21 # 將類別文本映射為數值類型 22 data['Species'] = data['Species'].map({'versicolor':0,'setosa':1,'virginica':2}) 23 # 刪除不需要的Id列,並改變原來的文本,以下有兩種方法 24 # data.drop('Id', axis=1, inplace = True) 25 data = data.drop('Id', axis=1) 26 27 # 查看是否有重復數據 28 # data.duplicated().any() 29 30 # 查看數據集的列數 31 # len(data) 32 33 # 刪除重復的記錄 34 data.drop_duplicates(inplace=True) 35 # len(data) 36 37 # 查看各個類別的鳶尾花有多少條記錄 38 data['Species'].value_counts() 39 40 class KNN: 41 '''使用Python語言實現K近鄰算法。(實現分類)''' 42 43 def __init__(self, k): 44 '''初始化方法 45 46 Parameters 47 ------ 48 k : int 49 鄰居的個數 50 51 ''' 52 self.k = k 53 54 def fit(self, X, y): 55 '''訓練方法 56 57 Parameters 58 ------ 59 X : 類數組類型,形狀為:{樣本數量,特征數量} 60 待訓練的樣本特征(屬性) 61 y : 類數組類型,形狀為:{樣本數量} 62 每個樣本的目標值(標簽) 63 64 ''' 65 # 將X轉換為array數組 66 self.X = np.asarray(X) 67 self.y = np.asarray(y) 68 69 def predict(self, X): 70 '''根據參數傳遞的樣本,對樣本數據進行預測。 71 72 Parameters 73 ------- 74 X : 類數組類型,形狀為:[樣本數量,特征數量] 75 待訓陳的樣本特征(屬性) 76 77 Returns 78 ---- 79 result : 數組類型 80 預測的結果 81 ''' 82 83 X = np.asarray(X) 84 result = [] 85 # 對array數組進行遍歷,每次取數組中的一行。 86 for x in X: 87 # 對於測試集中的每一個樣本,依次與訓練集中的所有樣本求距離。 88 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) 89 # 返回數組排序后,每個元素在原數組(排序之前的數組)中的索引 90 index = dis.argsort() 91 # 進行截斷,只取前k個元素。【取距離最近的k個元素的索引】 92 index = index[:self.k] 93 # 返回數組中每個元素出現的次數。元素必須是非負的整數 94 count = np.bincount(self.y[index]) 95 # 返回ndarray數組中值最大的元素對應的索引,該索引就是我們判定的索引 96 # 最大元素索引,就是出現次數最多的元素 97 result.append(count.argmax()) 98 return np.asarray(result) 99 100 101 def predict2(self, X): 102 '''根據參數傳遞的樣本,對樣本數據進行預測。 103 104 Parameters 105 ------- 106 X : 類數組類型,形狀為:[樣本數量,特征數量] 107 待訓陳的樣本特征(屬性) 108 109 Returns 110 ---- 111 result : 數組類型 112 預測的結果 113 ''' 114 115 X = np.asarray(X) 116 result = [] 117 # 對array數組進行遍歷,每次取數組中的一行。 118 for x in X: 119 # 對於測試集中的每一個樣本,依次與訓練集中的所有樣本求距離。 120 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) 121 # 返回數組排序后,每個元素在原數組(排序之前的數組)中的索引 122 index = dis.argsort() 123 # 進行截斷,只取前k個元素。【取距離最近的k個元素的索引】 124 index = index[:self.k] 125 # 返回數組中每個元素出現的次數。元素必須是非負的整數。【使用weights考慮權重,權重為距離的倒數】 126 count = np.bincount(self.y[index], weights=1 / dis[index]) 127 # 返回ndarray數組中值最大的元素對應的索引,該索引就是我們判定的索引 128 # 最大元素索引,就是出現次數最多的元素 129 result.append(count.argmax()) 130 return np.asarray(result) 131 132 # 提取出每個類別的鳶尾花數據 133 t0 = data[data['Species'] == 0] 134 t1 = data[data['Species'] == 1] 135 t2 = data[data['Species'] == 2] 136 # 對每個類別數據進行打亂洗牌 137 t0 = t0.sample(len(t0), random_state=0) 138 t1 = t1.sample(len(t1), random_state=0) 139 t2 = t2.sample(len(t2), random_state=0) 140 # 構建訓練集和測試集 141 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]], axis=0) 142 train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0) 143 test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0) 144 test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0) 145 # 創建KNN對象,進行訓練與測試 146 knn = KNN(k=3) 147 # 進行訓練 148 knn.fit(train_X, train_y) 149 # 進行測試,獲得測試的結果 150 result = knn.predict(test_X) 151 152 # 查看顯示 153 # display(result) 154 # display(test_y) 155 156 display(np.sum(result == test_y)) 157 display(np.sum(result == test_y)/ len(result)) 158 159 # 考慮權重,進行一下測試。 160 result2 = knn.predict2(test_X) 161 display(np.sum(result2 == test_y)) 162 163 # 如果想顯示中文的話,可以看這一段,默認情況下,matplotlib不支持中文顯示,進行以下設置 164 # 設置字體為黑體,以支持中文顯示 165 mpl.rcParams['font.family'] = 'SimHei' 166 # 設置在中文字體時,能夠正常的顯示負號(-) 167 mpl.rcParams['axes.unicode_minus'] = False 168 169 # 繪制數據集數據 170 # 設置畫布的大小 171 plt.figure(figsize=(10, 10)) 172 plt.scatter(x=t0['Sepal.Length'][:40], y=t0['Petal.Length'][:40], color='r', label='versicolor') 173 plt.scatter(x=t1['Sepal.Length'][:40], y=t1['Petal.Length'][:40], color='g', label='setosa') 174 plt.scatter(x=t2['Sepal.Length'][:40], y=t2['Petal.Length'][:40], color='b', label='virginica') 175 # 繪制測試集數據 176 right = test_X[result == test_y] 177 wrong = test_X[result != test_y] 178 plt.scatter(x=right['Sepal.Length'], y=right['Petal.Length'], color='c', marker='x', label='right') 179 plt.scatter(x=wrong['Sepal.Length'], y=wrong['Petal.Length'], color='m', marker='>', label='wrong') 180 # 英文顯示title、label 181 plt.xlabel('Sepal.Length') 182 plt.ylabel('Petal.Length') 183 plt.title('KNN classification') 184 plt.legend(loc='best') 185 plt.show()
課程學習鏈接來源:【超實用】Python實現機器學習算法(全)https://www.bilibili.com/video/BV1V7411P7wL