鸢尾花数据分类,通过Python实现KNN分类算法。
项目来源:https://aistudio.baidu.com/aistudio/projectdetail/1988428
数据集来源:鸢尾花数据集https://aistudio.baidu.com/aistudio/datasetdetail/91206
1 import numpy as np 2 import pandas as pd 3 import matplotlib as mpl 4 import matplotlib.pyplot as plt 5 6 # 读取鸢尾花数据集,header参数来指定标题的行。默认为0。如果没有标题,则使用None 7 # 四个特征分别为花萼长度sepal length,花萼宽度sepal width,花瓣长度petal length,花瓣宽度petal width。鸢尾花的种类,共有3种,分别为山鸢尾Iris Setosa、杂色鸢尾Iris Versicolour、维吉尼亚鸢尾Iris Virginica。 8 data = pd.read_csv('./data/data91206/iris.csv', header=0) 9 # 显示全部数据 10 # data 11 12 # 显示前n行的数据,默认n的值为5 13 # data.head() 14 15 # 显示末尾的n行记录,默认n的值为5 16 # data.tail() 17 18 # 随机抽取样本,默认抽取一条,我们可以通过修改参数来指定抽取样本的数量 19 data.sample(10) 20 21 # 将类别文本映射为数值类型 22 data['Species'] = data['Species'].map({'versicolor':0,'setosa':1,'virginica':2}) 23 # 删除不需要的Id列,并改变原来的文本,以下有两种方法 24 # data.drop('Id', axis=1, inplace = True) 25 data = data.drop('Id', axis=1) 26 27 # 查看是否有重复数据 28 # data.duplicated().any() 29 30 # 查看数据集的列数 31 # len(data) 32 33 # 删除重复的记录 34 data.drop_duplicates(inplace=True) 35 # len(data) 36 37 # 查看各个类别的鸢尾花有多少条记录 38 data['Species'].value_counts() 39 40 class KNN: 41 '''使用Python语言实现K近邻算法。(实现分类)''' 42 43 def __init__(self, k): 44 '''初始化方法 45 46 Parameters 47 ------ 48 k : int 49 邻居的个数 50 51 ''' 52 self.k = k 53 54 def fit(self, X, y): 55 '''训练方法 56 57 Parameters 58 ------ 59 X : 类数组类型,形状为:{样本数量,特征数量} 60 待训练的样本特征(属性) 61 y : 类数组类型,形状为:{样本数量} 62 每个样本的目标值(标签) 63 64 ''' 65 # 将X转换为array数组 66 self.X = np.asarray(X) 67 self.y = np.asarray(y) 68 69 def predict(self, X): 70 '''根据参数传递的样本,对样本数据进行预测。 71 72 Parameters 73 ------- 74 X : 类数组类型,形状为:[样本数量,特征数量] 75 待训陈的样本特征(属性) 76 77 Returns 78 ---- 79 result : 数组类型 80 预测的结果 81 ''' 82 83 X = np.asarray(X) 84 result = [] 85 # 对array数组进行遍历,每次取数组中的一行。 86 for x in X: 87 # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。 88 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) 89 # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引 90 index = dis.argsort() 91 # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】 92 index = index[:self.k] 93 # 返回数组中每个元素出现的次数。元素必须是非负的整数 94 count = np.bincount(self.y[index]) 95 # 返回ndarray数组中值最大的元素对应的索引,该索引就是我们判定的索引 96 # 最大元素索引,就是出现次数最多的元素 97 result.append(count.argmax()) 98 return np.asarray(result) 99 100 101 def predict2(self, X): 102 '''根据参数传递的样本,对样本数据进行预测。 103 104 Parameters 105 ------- 106 X : 类数组类型,形状为:[样本数量,特征数量] 107 待训陈的样本特征(属性) 108 109 Returns 110 ---- 111 result : 数组类型 112 预测的结果 113 ''' 114 115 X = np.asarray(X) 116 result = [] 117 # 对array数组进行遍历,每次取数组中的一行。 118 for x in X: 119 # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。 120 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) 121 # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引 122 index = dis.argsort() 123 # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】 124 index = index[:self.k] 125 # 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数】 126 count = np.bincount(self.y[index], weights=1 / dis[index]) 127 # 返回ndarray数组中值最大的元素对应的索引,该索引就是我们判定的索引 128 # 最大元素索引,就是出现次数最多的元素 129 result.append(count.argmax()) 130 return np.asarray(result) 131 132 # 提取出每个类别的鸢尾花数据 133 t0 = data[data['Species'] == 0] 134 t1 = data[data['Species'] == 1] 135 t2 = data[data['Species'] == 2] 136 # 对每个类别数据进行打乱洗牌 137 t0 = t0.sample(len(t0), random_state=0) 138 t1 = t1.sample(len(t1), random_state=0) 139 t2 = t2.sample(len(t2), random_state=0) 140 # 构建训练集和测试集 141 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]], axis=0) 142 train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0) 143 test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0) 144 test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0) 145 # 创建KNN对象,进行训练与测试 146 knn = KNN(k=3) 147 # 进行训练 148 knn.fit(train_X, train_y) 149 # 进行测试,获得测试的结果 150 result = knn.predict(test_X) 151 152 # 查看显示 153 # display(result) 154 # display(test_y) 155 156 display(np.sum(result == test_y)) 157 display(np.sum(result == test_y)/ len(result)) 158 159 # 考虑权重,进行一下测试。 160 result2 = knn.predict2(test_X) 161 display(np.sum(result2 == test_y)) 162 163 # 如果想显示中文的话,可以看这一段,默认情况下,matplotlib不支持中文显示,进行以下设置 164 # 设置字体为黑体,以支持中文显示 165 mpl.rcParams['font.family'] = 'SimHei' 166 # 设置在中文字体时,能够正常的显示负号(-) 167 mpl.rcParams['axes.unicode_minus'] = False 168 169 # 绘制数据集数据 170 # 设置画布的大小 171 plt.figure(figsize=(10, 10)) 172 plt.scatter(x=t0['Sepal.Length'][:40], y=t0['Petal.Length'][:40], color='r', label='versicolor') 173 plt.scatter(x=t1['Sepal.Length'][:40], y=t1['Petal.Length'][:40], color='g', label='setosa') 174 plt.scatter(x=t2['Sepal.Length'][:40], y=t2['Petal.Length'][:40], color='b', label='virginica') 175 # 绘制测试集数据 176 right = test_X[result == test_y] 177 wrong = test_X[result != test_y] 178 plt.scatter(x=right['Sepal.Length'], y=right['Petal.Length'], color='c', marker='x', label='right') 179 plt.scatter(x=wrong['Sepal.Length'], y=wrong['Petal.Length'], color='m', marker='>', label='wrong') 180 # 英文显示title、label 181 plt.xlabel('Sepal.Length') 182 plt.ylabel('Petal.Length') 183 plt.title('KNN classification') 184 plt.legend(loc='best') 185 plt.show()
课程学习链接来源:【超实用】Python实现机器学习算法(全)https://www.bilibili.com/video/BV1V7411P7wL