鳶尾花數據分類,通過Python實現KNN分類算法。


鳶尾花數據分類,通過Python實現KNN分類算法。

項目來源:https://aistudio.baidu.com/aistudio/projectdetail/1988428

數據集來源:鳶尾花數據集https://aistudio.baidu.com/aistudio/datasetdetail/91206

  1 import numpy as np 
  2 import pandas as pd
  3 import matplotlib as mpl 
  4 import matplotlib.pyplot as plt 
  5 
  6 # 讀取鳶尾花數據集,header參數來指定標題的行。默認為0。如果沒有標題,則使用None
  7 # 四個特征分別為花萼長度sepal length,花萼寬度sepal width,花瓣長度petal length,花瓣寬度petal width。鳶尾花的種類,共有3種,分別為山鳶尾Iris Setosa、雜色鳶尾Iris Versicolour、維吉尼亞鳶尾Iris Virginica。
  8 data = pd.read_csv('./data/data91206/iris.csv', header=0)
  9 # 顯示全部數據
 10 # data
 11 
 12 # 顯示前n行的數據,默認n的值為5
 13 # data.head()
 14 
 15 # 顯示末尾的n行記錄,默認n的值為5
 16 # data.tail()
 17 
 18 # 隨機抽取樣本,默認抽取一條,我們可以通過修改參數來指定抽取樣本的數量
 19 data.sample(10)
 20 
 21 # 將類別文本映射為數值類型
 22 data['Species'] = data['Species'].map({'versicolor':0,'setosa':1,'virginica':2})
 23 # 刪除不需要的Id列,並改變原來的文本,以下有兩種方法
 24 # data.drop('Id', axis=1, inplace = True)
 25 data = data.drop('Id', axis=1)
 26 
 27 # 查看是否有重復數據
 28 # data.duplicated().any()
 29 
 30 # 查看數據集的列數
 31 # len(data)
 32 
 33 # 刪除重復的記錄
 34 data.drop_duplicates(inplace=True)
 35 # len(data)
 36 
 37 # 查看各個類別的鳶尾花有多少條記錄
 38 data['Species'].value_counts()
 39 
 40 class KNN:
 41     '''使用Python語言實現K近鄰算法。(實現分類)'''
 42     
 43     def __init__(self, k):
 44         '''初始化方法
 45         
 46         Parameters
 47         ------
 48         k : int
 49            鄰居的個數
 50            
 51         '''
 52         self.k = k
 53         
 54     def fit(self, X, y):
 55         '''訓練方法
 56         
 57         Parameters
 58         ------
 59         X : 類數組類型,形狀為:{樣本數量,特征數量}
 60             待訓練的樣本特征(屬性)
 61         y : 類數組類型,形狀為:{樣本數量}
 62             每個樣本的目標值(標簽)
 63         
 64         '''
 65         # 將X轉換為array數組
 66         self.X = np.asarray(X)
 67         self.y = np.asarray(y)
 68         
 69     def predict(self, X):
 70         '''根據參數傳遞的樣本,對樣本數據進行預測。
 71         
 72         Parameters
 73         -------
 74         X : 類數組類型,形狀為:[樣本數量,特征數量]
 75             待訓陳的樣本特征(屬性)
 76                        
 77         Returns
 78         ----
 79         result : 數組類型
 80                 預測的結果
 81         '''
 82         
 83         X = np.asarray(X)
 84         result = []
 85         # 對array數組進行遍歷,每次取數組中的一行。
 86         for x in X:
 87             # 對於測試集中的每一個樣本,依次與訓練集中的所有樣本求距離。
 88             dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
 89             # 返回數組排序后,每個元素在原數組(排序之前的數組)中的索引
 90             index = dis.argsort()
 91             # 進行截斷,只取前k個元素。【取距離最近的k個元素的索引】
 92             index = index[:self.k]
 93             # 返回數組中每個元素出現的次數。元素必須是非負的整數
 94             count = np.bincount(self.y[index])
 95             # 返回ndarray數組中值最大的元素對應的索引,該索引就是我們判定的索引
 96             # 最大元素索引,就是出現次數最多的元素
 97             result.append(count.argmax())
 98         return np.asarray(result)
 99     
100     
101     def predict2(self, X):
102         '''根據參數傳遞的樣本,對樣本數據進行預測。
103         
104         Parameters
105         -------
106         X : 類數組類型,形狀為:[樣本數量,特征數量]
107             待訓陳的樣本特征(屬性)
108             
109         Returns
110         ----
111         result : 數組類型
112                 預測的結果
113         '''
114 
115         X = np.asarray(X)
116         result = []
117         # 對array數組進行遍歷,每次取數組中的一行。
118         for x in X:
119             # 對於測試集中的每一個樣本,依次與訓練集中的所有樣本求距離。
120             dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
121             # 返回數組排序后,每個元素在原數組(排序之前的數組)中的索引
122             index = dis.argsort()
123             # 進行截斷,只取前k個元素。【取距離最近的k個元素的索引】
124             index = index[:self.k]
125             # 返回數組中每個元素出現的次數。元素必須是非負的整數。【使用weights考慮權重,權重為距離的倒數】      
126             count = np.bincount(self.y[index], weights=1 / dis[index])
127             # 返回ndarray數組中值最大的元素對應的索引,該索引就是我們判定的索引
128             # 最大元素索引,就是出現次數最多的元素
129             result.append(count.argmax())
130         return np.asarray(result)
131 
132 # 提取出每個類別的鳶尾花數據
133 t0 = data[data['Species'] == 0]
134 t1 = data[data['Species'] == 1]
135 t2 = data[data['Species'] == 2]
136 # 對每個類別數據進行打亂洗牌
137 t0 = t0.sample(len(t0), random_state=0)
138 t1 = t1.sample(len(t1), random_state=0)
139 t2 = t2.sample(len(t2), random_state=0)
140 # 構建訓練集和測試集
141 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]], axis=0)
142 train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0)
143 test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0)
144 test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0)
145 # 創建KNN對象,進行訓練與測試
146 knn = KNN(k=3)
147 # 進行訓練
148 knn.fit(train_X, train_y)
149 # 進行測試,獲得測試的結果
150 result = knn.predict(test_X)
151 
152 # 查看顯示
153 # display(result)
154 # display(test_y)
155 
156 display(np.sum(result == test_y))
157 display(np.sum(result == test_y)/ len(result))
158 
159 # 考慮權重,進行一下測試。
160 result2 = knn.predict2(test_X)
161 display(np.sum(result2 == test_y))
162 
163 # 如果想顯示中文的話,可以看這一段,默認情況下,matplotlib不支持中文顯示,進行以下設置
164 # 設置字體為黑體,以支持中文顯示
165 mpl.rcParams['font.family'] = 'SimHei'
166 # 設置在中文字體時,能夠正常的顯示負號(-)
167 mpl.rcParams['axes.unicode_minus'] = False
168 
169 # 繪制數據集數據
170 # 設置畫布的大小
171 plt.figure(figsize=(10, 10))
172 plt.scatter(x=t0['Sepal.Length'][:40], y=t0['Petal.Length'][:40], color='r', label='versicolor')  
173 plt.scatter(x=t1['Sepal.Length'][:40], y=t1['Petal.Length'][:40], color='g', label='setosa')  
174 plt.scatter(x=t2['Sepal.Length'][:40], y=t2['Petal.Length'][:40], color='b', label='virginica')  
175 # 繪制測試集數據
176 right = test_X[result == test_y]
177 wrong = test_X[result != test_y]
178 plt.scatter(x=right['Sepal.Length'], y=right['Petal.Length'], color='c', marker='x', label='right')  
179 plt.scatter(x=wrong['Sepal.Length'], y=wrong['Petal.Length'], color='m', marker='>', label='wrong')  
180 # 英文顯示title、label
181 plt.xlabel('Sepal.Length')
182 plt.ylabel('Petal.Length')
183 plt.title('KNN classification')
184 plt.legend(loc='best')
185 plt.show()

課程學習鏈接來源:【超實用】Python實現機器學習算法(全)https://www.bilibili.com/video/BV1V7411P7wL


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM