KNN分類算法及python代碼實現


KNN分類算法(先驗數據中就有類別之分,未知的數據會被歸類為之前類別中的某一類!

1、KNN介紹

K最近鄰(k-Nearest NeighborKNN分類算法是最簡單的機器學習算法。

機器學習,算法本身不是最難的,最難的是:

1、數學建模:把業務中的特性抽象成向量的過程;

2、選取適合模型的數據樣本。

這兩個事都不是簡單的事。算法反而是比較簡單的事。

本質上,KNN算法就是用距離來衡量樣本之間的相似度。

 

2、算法圖示

◊ 從訓練集中找到和新數據最接近的k條記錄,然后根據多數類來決定新數據類別。

◊算法涉及3個主要因素:

1) 訓練數據集

2) 距離或相似度的計算衡量

3) k的大小

 

◊算法描述

1) 已知兩類“先驗”數據,分別是藍方塊和紅三角,他們分布在一個二維空間中

2) 有一個未知類別的數據(綠點),需要判斷它是屬於“藍方塊”還是“紅三角”類

3) 考察離綠點最近的3個(或k)數據點的類別,占多數的類別即為綠點判定類別

 

3、算法要點

3.1、計算步驟

 1算距離:給定測試對象,計算它與訓練集中的每個對象的距離

 2找鄰居:圈定距離最近的k個訓練對象,作為測試對象的近鄰

 3做分類:根據這k個近鄰歸屬的主要類別,來對測試對象分類

 

3.2、相似度的度量

◊距離越近應該意味着這兩個點屬於一個分類的可能性越大。

但,距離不能代表一切,有些數據的相似度衡量並不適合用距離

◊相似度衡量方法:包括歐式距離夾角余弦等。

(簡單應用中,一般使用歐氏距離,但對於文本分類來說,使用余弦(cosine)來計算相似度就比歐式(Euclidean)距離更合適

 

3.3、類別的判定

簡單投票法:少數服從多數,近鄰中哪個類別的點最多就分為該類。

加權投票法:根據距離的遠近,對近鄰的投票進行加權,距離越近則權重越大(權重為距離平方的倒數)

 

3.4、算法不足

  • 樣本不平衡容易導致結果錯誤

◊如一個類的樣本容量很大,而其他類樣本容量很小時,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本占多數。

◊改善方法:對此可以采用權值的方法(和該樣本距離小的鄰居權值大)來改進。

  • 計算量較大

◊因為對每一個待分類的文本都要計算它到全體已知樣本的距離,才能求得它的K個最近鄰點。

◊改善方法:事先對已知樣本點進行剪輯,事先去除對分類作用不大的樣本。

該方法比較適用於樣本容量比較大的類域的分類,而那些樣本容量較小的類域采用這種算法比較容易產生誤分。

 

4、KNN分類算法python實現(python2.7)

需求:

有以下先驗數據,使用knn算法對未知類別數據分類

屬性1

屬性2

類別

1.0

0.9

A

1.0

1.0

A

0.1

0.2

B

0.0

0.1

B

 

未知類別數據

屬性1

屬性2

類別

1.2

1.0

?

0.1

0.3

?

 

python實現:

KNN.py腳本文件

 1 #!/usr/bin/python
 2 # coding=utf-8
 3 #########################################
 4 # kNN: k Nearest Neighbors
 5 
 6 #  輸入:      newInput:  (1xN)的待分類向量
 7 #             dataSet:   (NxM)的訓練數據集
 8 #             labels:     訓練數據集的類別標簽向量
 9 #             k:         近鄰數
10 
11 # 輸出:     可能性最大的分類標簽
12 #########################################
13 
14 from numpy import *
15 import operator
16 
17 # 創建一個數據集,包含2個類別共4個樣本
18 def createDataSet():
19     # 生成一個矩陣,每行表示一個樣本
20     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
21     # 4個樣本分別所屬的類別
22     labels = ['A', 'A', 'B', 'B']
23     return group, labels
24 
25 # KNN分類算法函數定義
26 def kNNClassify(newInput, dataSet, labels, k):
27     numSamples = dataSet.shape[0]   # shape[0]表示行數
28 
29     # # step 1: 計算距離[
30     # 假如:
31     # Newinput:[1,0,2]
32     # Dataset:
33     # [1,0,1]
34     # [2,1,3]
35     # [1,0,2]
36     # 計算過程即為:
37     # 1、求差
38     # [1,0,1]       [1,0,2]
39     # [2,1,3]   --   [1,0,2]
40     # [1,0,2]       [1,0,2]
41     # =
42     # [0,0,-1]
43     # [1,1,1]
44     # [0,0,-1]
45     # 2、對差值平方
46     # [0,0,1]
47     # [1,1,1]
48     # [0,0,1]
49     # 3、將平方后的差值累加
50     # [1]
51     # [3]
52     # [1]
53     # 4、將上一步驟的值求開方,即得距離
54     # [1]
55     # [1.73]
56     # [1]
57     #
58     # ]
59     # tile(A, reps): 構造一個矩陣,通過A重復reps次得到
60     # the following copy numSamples rows for dataSet
61     diff = tile(newInput, (numSamples, 1)) - dataSet  # 按元素求差值
62     squaredDiff = diff ** 2  # 將差值平方
63     squaredDist = sum(squaredDiff, axis = 1)   # 按行累加
64     distance = squaredDist ** 0.5  # 將差值平方和求開方,即得距離
65 
66     # # step 2: 對距離排序
67     # argsort() 返回排序后的索引值
68     sortedDistIndices = argsort(distance)
69     classCount = {} # define a dictionary (can be append element)
70     for i in xrange(k):
71         # # step 3: 選擇k個最近鄰
72         voteLabel = labels[sortedDistIndices[i]]
73 
74         # # step 4: 計算k個最近鄰中各類別出現的次數
75         # when the key voteLabel is not in dictionary classCount, get()
76         # will return 0
77         classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
78 
79     # # step 5: 返回出現次數最多的類別標簽
80     maxCount = 0
81     for key, value in classCount.items():
82         if value > maxCount:
83             maxCount = value
84             maxIndex = key
85 
86     return maxIndex

 

KNNTest.py測試文件

 1 #!/usr/bin/python
 2 # coding=utf-8
 3 import KNN
 4 from numpy import *
 5 # 生成數據集和類別標簽
 6 dataSet, labels = KNN.createDataSet()
 7 # 定義一個未知類別的數據
 8 testX = array([1.2, 1.0])
 9 k = 3
10 # 調用分類函數對未知數據分類
11 outputLabel = KNN.kNNClassify(testX, dataSet, labels, 3)
12 print "Your input is:", testX, "and classified to class: ", outputLabel
13 
14 testX = array([0.1, 0.3])
15 outputLabel = KNN.kNNClassify(testX, dataSet, labels, 3)
16 print "Your input is:", testX, "and classified to class: ", outputLabel

 

運行結果:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM