看了原理,總覺得需要用具體問題實現一下機器學習算法的模型,才算學習深刻。而寫此博文的目的是,網上關於K-NN解決此問題的博文很多,但大都是調用Python高級庫實現,尤其不利於初級學習者本人對模型的理解和工程實踐能力的提升,也不利於Python初學者實現該模型。
本博文的特點:
一 全面性地總結K-NN模型的特征、用途
二 基於Python的內置模塊,不調用任何第三方庫實現
博文主要分為四部分:
基本模型(便於理清概念、回顧模型)
對待解決問題的重述
模型(算法)和評價(一來,以便了解模型特點,為以后舉一反三地應用作鋪墊;二來,有利於以后快速復習)、
編程實現(Code)。
特別聲明:
1.勞動成果開源,未經同意博主(千千寰宇:http://cnblogs.com/johnnyzen),不得以任何形式轉載、復制。
2.如有紕漏或者其他看法,歡迎共同探討~
零 基本模型
(本部分內容,均來源於引用[1],其原理講解十分通俗易懂)
①K-近鄰算法,即K-Nearest Neighbor algorithm,簡稱K-NN算法。單從名字來猜想,可以簡單粗暴的認為是:K個最近的鄰居,當K=1時,算法便成了最近鄰算法,即尋找最近的那個鄰居。
②所謂K-NN算法,即是給定一個訓練數據集,對新的輸入實例,在訓練數據集中找到與該實例最鄰近的K個實例(也就是K個鄰居), 這K個實例的多數屬於某個類,就把該輸入實例分類到這個類中。
③實例
猜猜看:有一個未知形狀(綠色圓點),如何判斷其是什么形狀?
問題:給這個綠色的圓分類?
對噪聲數據過於敏感。為了解決這個問題,我們可以把位置樣本周邊的多個最近樣本計算在內,擴大參與決策的樣本量,以避免個別數據直接決定決策結果。
有兩類不同的樣本數據,分別用藍色的小正方形和紅色的小三角形表示,而圖正中間的那個綠色的圓所標示的數據則是待分類的數據。
如果K=3,判定綠色的這個待分類點屬於紅色的三角形一類。
如果K=5,判定綠色的這個待分類點屬於藍色的正方形一類。
一 問題
題目:ML之k-NN:k-NN實現對150朵共三種花的實例的萼片長度、寬,花瓣長、寬數據統計,根據一朵新花的四個特征來預測其種類
數據源:https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
數據源說明:https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names
二 解決過程及模型評價


由於前面第二部分已經詳細敘述,且代碼中注釋已經十分詳細,便不在對代碼進行解釋,閱讀注釋便容易懂。
__init__.py
import random; import math; import Iris; # 自定義 import file_handle; # 自定義 #if __name__ == '__main__': #__name__ == '__main__'是Python的main函數入口 def main(print_test=False,print_samples=False): follows = []; # 樣本集空間(前sampleAmount項)+測試集空間(后 sampleAll-sampleAmount項) data = ""; # 樣本數據(前sampleAmount項)+測試集空間(后 sampleAll-sampleAmount項) sampleAll = 150; sampleAmount = 100; # 標記樣本集數目(剩余的便作為測試集) k = 5; test_print = print_test; ########## 一 讀取數據集,裝載樣本集 ##### 1.1 加載數據集數據 data = file_handle.read("./dataset/data.txt",1,'r');#1:忽略第一行 # print(data); list = data.split('\n'); i = 0; for line in list: # 如:line = "5.1,3.5,1.4,0.2,Iris-setosa" item = line.split(',');# 如:item = [5.1,3.5,1.4,0.2,Iris-setosa] label_species = item.pop();#移除最后一項:標記種類 #print("[test] item:", item,"\tlabel_species:", label_species); # test follows.append(Iris.Iris(item,label_species)); #print("[ ",i," ] ",follows[i].toString()); i += 1; pass; random.shuffle(follows); # 【千萬注意!!!】由於原數據集是有序的,如果不做亂序處理,預測結果會及其不理想(准確率,趨近於0),當然,這也是這一模型的缺陷之一 ##### 1.2 選擇前100項 作為已標記樣本集 #i = 0; #for i in range(sampleAmount): # follows[i].setPredictSpecies(follows[i].label_species); # pass; ########## 二 訓練測試樣本 ##### 2.1 對101 - 150 項的測試集進行訓練/預測 ## 算法描述: ## 遍歷測試樣本 ## 計算測試樣本與已標記的樣本的歐式距離 ## 對各歐式距離升序排序 ## 選擇前K項的已樣本作為一子集( 即 選擇最近的K項鄰居作為參照標准) ## 遍歷統計,已標記子集的花朵種類何種花朵數目種類最多 ## 設置當前測試樣本的預測花朵種類為該種 ## 結束。 ## 注釋:花的種類分別為:Iris-setosa、Iris-versicolor、Iris-virginica;共計3種。 offset = 0; # 測試空間偏移量:目的是為了將通過偏移量,增大原已標記樣本空間的樣本數量 即 使已預測的測試樣本加入參照樣本空間。 for x in range(sampleAmount,sampleAll):# x:測試樣本下標 weights = [];# 對各歐式距離(權值)的升序排序列表 for y in range(0,sampleAmount+offset): result = (math.sqrt( + \ math.pow(follows[y].features[0] - follows[x].features[0],2) + \ math.pow(follows[y].features[1] - follows[x].features[1],2) + \ math.pow(follows[y].features[2] - follows[x].features[2],2) + \ math.pow(follows[y].features[3] - follows[x].features[3],2)), y);# 存儲x,方便排序后定位花朵 #print("[test] weights[x]:", result); weights.append(result); pass; weights.sort(key = lambda item:item[0]); # 以各元組內第一首項[歐氏距離]為鍵,默認升序排序 if test_print: for m in range(len(weights)): # 輸出預測權重 print("[test] weights[",m,"]:",weights[m],"\tweights[",m,"][1] > ",weights[m][1],":",follows[weights[m][1]].toString()); kinds_count = {"Iris-setosa":0,"Iris-versicolor":0,"Iris-virginica":0}; # 對已標記樣本空間中各種花的數目統計作初始化 for z in range(0,k): # 選擇前K項的已樣本作為一子集( 即 選擇最近的K項鄰居作為參照標准) if test_print: print("[test] 排名前",z+1,"項 follows[",z,"]:",follows[weights[z][1]].toString()); label_species = follows[weights[z][1]].label_species; if(label_species == 'Iris-setosa'): kinds_count["Iris-setosa"] += 1; elif label_species == 'Iris-versicolor': kinds_count["Iris-versicolor"] += 1; elif label_species == 'Iris-virginica': kinds_count['Iris-virginica'] += 1; else: print("[ERROR:Unknown Species] follows[",weight[z][1],"]:",follows[weight[z][1]]); pass; result = max(kinds_count.items(), key = lambda item:item[1]); # 取統計花類數字典中最大值對應的序列 follows[x].predict_species = result[0]; # 標記預測種類 if test_print: print("[test] 預測結果",result, " [follows[",x,"].predict_species]:", follows[x].predict_species); # test offset += 1; #for test in range(len(weights)): # 測試-輸出距離權值結果 # print("[",test,"] weights:",weights[test][0],"\t",follows[weights[test][1]].toString()); # pass; pass; ########## 三 計算預測准確率 rate = 0.0; i = 0; for i in range(sampleAmount,len(follows)): if(follows[i].label_species == follows[i].predict_species): rate += 1; else: print("[預測錯誤樣本] follow[",i,"]:",follows[i].toString()); pass; pass; rate = rate / (sampleAll - sampleAmount); print("預測准確率:",rate); if print_samples: for i in range(0,len(follows)): print(follows[i].toString()); pass; pass; main(False,True);
Iris.py
'Iris module [class] ' __author__ = 'Johnny Zen' class Iris: """ Iris花(類) [Demo] iris = Iris([5.1,3.5,1.4,0.2],"Iris-setosa"); print(iris.toString()); iris.setPredictSpecies('Iris-setosa'); print(iris.toString()); print(iris.label_species); ======================= [features][5.1, 3.5, 1.4, 0.2] [label-species]Iris-setosa [predict-species]None [features][5.1, 3.5, 1.4, 0.2] [label-species]Iris-setosa [predict-species]Iris-setosa Iris-setosa """ features = []; label_species = None; # 標記種類 predict_species = None; # 預測種類 def __init__(self,features,label_species=None): if type(features).__name__ == 'list': self.features = features; else: self.features = list(features); # 此list方法對list對象執行將產生錯誤 pass; for x in range(len(self.features)): # 列表內元素字符串轉實數 self.features[x] = float(self.features[x]); self.label_species = label_species; pass; def setPredictSpecies(self,predict_species=None):#設置預測種類 self.predict_species = predict_species; pass; def toString(self):#與一般函數定義不同,類方法必須包含參數 self[第一個參數] return "[features]" + str(self.features) + "\t[label]" + str(self.label_species + "\t[predict]" + str(self.predict_species)); pass; pass;
file_handle.py
"file_handle module [function]:read(filepath,ignore=0,mode='r')" def read(filepath,ignore=0,mode='r'): try: file = open(filepath,mode); ## file_content = file.read(); file_content = ''; i = 0; for i in range(0,ignore): file.readline(); ##print(i); ##print(file.readline()); for line in file.readlines(): file_content += line; finally: if file: file.close(); #print(file_content); return file_content; pass;
四 參考文獻
[1] K-NN和K-Means算法