kNN算法:K最近鄰(kNN,k-NearestNeighbor)分類算法


 

一、KNN算法概述

  鄰近算法,或者說K最近鄰(kNN,k-NearestNeighbor)分類算法是數據挖掘分類技術中最簡單的方法之一。所謂K最近鄰,就是k個最近的鄰居的意思,說的是每個樣本都可以用它最接近的k個鄰居來代表。Cover和Hart在1968年提出了最初的鄰近算法。KNN是一種分類(classification)算法,它輸入基於實例的學習(instance-based learning),屬於懶惰學習(lazy learning)即KNN沒有顯式的學習過程,也就是說沒有訓練階段,數據集事先已有了分類和特征值,待收到新樣本后直接進行處理。與急切學習(eager learning)相對應。

  KNN是通過測量不同特征值之間的距離進行分類。 

  思路是:如果一個樣本在特征空間中的k個最鄰近的樣本中的大多數屬於某一個類別,則該樣本也划分為這個類別。KNN算法中,所選擇的鄰居都是已經正確分類的對象。該方法在定類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。

  提到KNN,網上最常見的就是下面這個圖,可以幫助大家理解。

  我們要確定綠點屬於哪個顏色(紅色或者藍色),要做的就是選出距離目標點距離最近的k個點,看這k個點的大多數顏色是什么顏色。當k取3的時候,我們可以看出距離最近的三個,分別是紅色、紅色、藍色,因此得到目標點為紅色。

  算法的描述:

  1)計算測試數據與各個訓練數據之間的距離;

  2)按照距離的遞增關系進行排序;

  3)選取距離最小的K個點;

  4)確定前K個點所在類別的出現頻率;

  5)返回前K個點中出現頻率最高的類別作為測試數據的預測分類

 

二、關於K的取值

  K:臨近數,即在預測目標點時取幾個臨近的點來預測。

  K值得選取非常重要,因為:

  如果當K的取值過小時,一旦有噪聲得成分存在們將會對預測產生比較大影響,例如取K值為1時,一旦最近的一個點是噪聲,那么就會出現偏差,K值的減小就意味着整體模型變得復雜,容易發生過擬合;

  如果K的值取的過大時,就相當於用較大鄰域中的訓練實例進行預測,學習的近似誤差會增大。這時與輸入目標點較遠實例也會對預測起作用,使預測發生錯誤。K值的增大就意味着整體的模型變得簡單;

  如果K==N的時候,那么就是取全部的實例,即為取實例中某分類下最多的點,就對預測沒有什么實際的意義了;

  K的取值盡量要取奇數,以保證在計算結果最后會產生一個較多的類別,如果取偶數可能會產生相等的情況,不利於預測。

 

  K的取法:

   常用的方法是從k=1開始,使用檢驗集估計分類器的誤差率。重復該過程,每次K增值1,允許增加一個近鄰。選取產生最小誤差率的K。

  一般k的取值不超過20,上限是n的開方,隨着數據集的增大,K的值也要增大。

 

三、關於距離的選取

  距離就是平面上兩個點的直線距離

  關於距離的度量方法,常用的有:歐幾里得距離、余弦值(cos), 相關度 (correlation), 曼哈頓距離 (Manhattan distance)或其他。

  Euclidean Distance 定義:

  兩個點或元組P1=(x1,y1)和P2=(x2,y2)的歐幾里得距離是

 

 

 

  距離公式為:(多個維度的時候是多個維度各自求差)

 

 

 

四、總結

  KNN算法是最簡單有效的分類算法,簡單且容易實現。當訓練數據集很大時,需要大量的存儲空間,而且需要計算待測樣本和訓練數據集中所有樣本的距離,所以非常耗時

  KNN對於隨機分布的數據集分類效果較差,對於類內間距小,類間間距大的數據集分類效果好,而且對於邊界不規則的數據效果好於線性分類器。

  KNN對於樣本不均衡的數據效果不好,需要進行改進。改進的方法時對k個近鄰數據賦予權重,比如距離測試樣本越近,權重越大。

  KNN很耗時,時間復雜度為O(n),一般適用於樣本數較少的數據集,當數據量大時,可以將數據以樹的形式呈現,能提高速度,常用的有kd-tree和ball-tree。

  (弱小無助。。。根據許多大佬的總結整理的)

 

五、Python實現

  根據算法的步驟,進行kNN的實現,完整代碼如下

 1 #!/usr/bin/env python
 2 # -*- coding:utf-8 -*-
 3 # Author: JYRoooy
 4 import csv
 5 import random
 6 import math
 7 import operator
 8 
 9 # 加載數據集
10 def loadDataset(filename, split, trainingSet = [], testSet = []):
11     with open(filename, 'r') as csvfile:
12         lines = csv.reader(csvfile)
13         dataset = list(lines)
14         for x in range(len(dataset)-1):
15             for y in range(4):
16                 dataset[x][y] = float(dataset[x][y])
17             if random.random() < split:  #將數據集隨機划分
18                 trainingSet.append(dataset[x])
19             else:
20                 testSet.append(dataset[x])
21 
22 # 計算點之間的距離,多維度的
23 def euclideanDistance(instance1, instance2, length):
24     distance = 0
25     for x in range(length):
26         distance += pow((instance1[x]-instance2[x]), 2)
27     return math.sqrt(distance)
28 
29 # 獲取k個鄰居
30 def getNeighbors(trainingSet, testInstance, k):
31     distances = []
32     length = len(testInstance)-1
33     for x in range(len(trainingSet)):
34         dist = euclideanDistance(testInstance, trainingSet[x], length)
35         distances.append((trainingSet[x], dist))   #獲取到測試點到其他點的距離
36     distances.sort(key=operator.itemgetter(1))    #對所有的距離進行排序
37     neighbors = []
38     for x in range(k):   #獲取到距離最近的k個點
39         neighbors.append(distances[x][0])
40         return neighbors
41 
42 # 得到這k個鄰居的分類中最多的那一類
43 def getResponse(neighbors):
44     classVotes = {}
45     for x in range(len(neighbors)):
46         response = neighbors[x][-1]
47         if response in classVotes:
48             classVotes[response] += 1
49         else:
50             classVotes[response] = 1
51     sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
52     return sortedVotes[0][0]   #獲取到票數最多的類別
53 
54 #計算預測的准確率
55 def getAccuracy(testSet, predictions):
56     correct = 0
57     for x in range(len(testSet)):
58         if testSet[x][-1] == predictions[x]:
59             correct += 1
60     return (correct/float(len(testSet)))*100.0
61 
62 
63 def main():
64     #prepare data
65     trainingSet = []
66     testSet = []
67     split = 0.67
68     loadDataset(r'irisdata.txt', split, trainingSet, testSet)
69     print('Trainset: ' + repr(len(trainingSet)))
70     print('Testset: ' + repr(len(testSet)))
71     #generate predictions
72     predictions = []
73     k = 3
74     for x in range(len(testSet)):
75         # trainingsettrainingSet[x]
76         neighbors = getNeighbors(trainingSet, testSet[x], k)
77         result = getResponse(neighbors)
78         predictions.append(result)
79         print ('predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1]))
80     print('predictions: ' + repr(predictions))
81     accuracy = getAccuracy(testSet, predictions)
82     print('Accuracy: ' + repr(accuracy) + '%')
83 
84 if __name__ == '__main__':
85     main()

 

六、sklearn庫的應用

  我利用了sklearn庫來進行了kNN的應用(這個庫是真的很方便了,可以借助這個庫好好學習一下,我是用KNN算法進行了根據成績來預測,這里用一個花瓣萼片的實例,因為這篇主要是關於KNN的知識,所以不對sklearn的過多的分析,而且我用的還不深入😅)

  sklearn庫內的算法與自己手搓的相比功能更強大、拓展性更優異、易用性也更強。還是很受歡迎的。(確實好用,簡單)

1 from sklearn import neighbors   //包含有kNN算法的模塊
2 from sklearn import datasets    //一些數據集的模塊

  調用KNN的分類器

1 knn = neighbors.KNeighborsClassifier()

  預測花瓣代碼

from sklearn import neighbors          
from sklearn import datasets

knn = neighbors.KNeighborsClassifier()

iris = datasets.load_iris()

# f = open("iris.data.csv", 'wb')              #可以保存數據
# f.write(str(iris))
# f.close()

print iris

knn.fit(iris.data, iris.target)                 #用KNN的分類器進行建模,這里利用的默認的參數,大家可以自行查閱文檔

predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]])

print ("predictedLabel is :" + predictedLabel)

 

  上面的例子是只預測了一個,也可以進行數據集的拆分,將數據集划分為訓練集和測試集

from sklearn.mode_selection import train_test_split   #引入數據集拆分的模塊

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

  

  關於 train_test_split 函數參數的說明:

  train_data:被划分的樣本特征集

  train_target:被划分的樣本標簽

  test_size:float-獲得多大比重的測試樣本 (默認:0.25)

       int - 獲得多少個測試樣本

  random_state:是隨機數的種子。

 

 


 

寫在后面

  本人的在學習機器學習和深度學習算法中的源碼 github地址 https://github.com/JYRoy/MachineLearning

  還是在校大學生,知識面不全,也參考了網上許多大佬的博客,一些個人的理解與應用可能有問題,歡迎大家指正,有學到新的相關知識也會對文章進行更新。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM