一、概述
k-近鄰算法采用測量不同特征值之間的距離方法進行分類。
工作原理:首先有一個樣本數據集合(訓練樣本集),並且樣本數據集合中每條數據都存在標簽(分類),即我們知道樣本數據中每一條數據與所屬分類的對應關系,輸入沒有標簽的數據之后,將新數據的每個特征與樣本集的數據對應的特征進行比較(歐式距離運算),然后算出新數據與樣本集中特征最相似(最近鄰)的數據的分類標簽,一般我們選擇樣本數據集中前k個最相似的數據,然后再從k個數據集中選出出現分類最多的分類作為新數據的分類。
二、優缺點
優點:精度高、對異常值不敏感、無數據輸入假定。
缺點:計算度復雜、空間度復雜。
適用范圍:數值型和標稱型
三、數學公式
歐式距離:歐氏距離是最易於理解的一種距離計算方法,源自歐氏空間中兩點間的距離公式。
(1)二維平面上兩點a(x1,y1)與b(x2,y2)間的歐氏距離:
(2)三維空間兩點a(x1,y1,z1)與b(x2,y2,z2)間的歐氏距離:
(3)兩個n維向量a(x11,x12,…,x1n)與 b(x21,x22,…,x2n)間的歐氏距離:
三、算法實現
k-近鄰算法的偽代碼
對未知類型屬性的數據集中的每個點依次執行以下操作:
(1) 計算已知類別數據集中的點與當前點之間的距離;
(2) 按照距離增序排序;
(3) 選取與當前點距離最近的k個點;
(4) 決定這k個點所屬類別的出現頻率;
(5) 返回前k個點出現頻率最高的類別作為當前點的預測分類。
1、構造數據
1 def createDataSet(): 2 group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) 3 labels = ['A','A','B','B'] 4 return group, labels
這里有4組數據,每組數據的列代表不同屬性的特征值,向量labels包含了每個數據點的標簽信息,也可以叫分類。這里有兩類數據,A和B。
2、實施算法
tile:重復某個數組。比如tile(A,n),功能是將數組A重復n次,構成一個新的數組.
1 >>> tile([1,2],(4)) 2 array([1, 2, 1, 2, 1, 2, 1, 2]) 3 >>> tile([1,2],(4,1)) 4 array([[1, 2], 5 [1, 2], 6 [1, 2], 7 [1, 2]]) 8 >>> tile([1,2],(4,2)) 9 array([[1, 2, 1, 2], 10 [1, 2, 1, 2], 11 [1, 2, 1, 2], 12 [1, 2, 1, 2]])
歐式距離算法實現:
1 def classify0(inX, dataSet, labels, k): 2 dataSetSize = dataSet.shape[0] 3 diffMat = tile(inX, (dataSetSize,1)) - dataSet #新數據與樣本數據每一行的值相減 [[x-x1,y-y1],[x-x2,y-y2],[x-x3,y-y3],.....] 4 sqDiffMat = diffMat**2 #數組每一項進行平方[[(x-x1)^2,(y-y1)^2],........] 5 sqDistances = sqDiffMat.sum(axis=1)#數組每個特證求和[[(x-xi)^2+(y-yi)^2],......] 6 distances = sqDistances**0.5 #數組每個值 開根號 ,,歐式距離公式 完成。。。。 7 sortedDistIndicies = distances.argsort() #argsort函數返回的是數組值從小到大的索引值 8 classCount={} #以下是選取 距離最小的前k個值的索引,從k個中選取分類最多的一個作為新數據的分類 9 for i in range(k):# 統計前k個點所屬的類別 10 voteIlabel = labels[sortedDistIndicies[i]] 11 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 12 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 13 return sortedClassCount[0][0]# 返回前k個點中頻率最高的類別
其中 inX:需要分類的新數據,dataSet:樣本數據特征,labels:樣本數據分類,k:選取前k個最近的距離
測試算法:
1 >>> group,labels = kNN.createDataSet() 2 >>> group,labels 3 (array([[ 1. , 1.1], 4 [ 1. , 1. ], 5 [ 0. , 0. ], 6 [ 0. , 0.1]]), ['A', 'A', 'B', 'B']) 7 >>> kNN.classify0([0,0],group,labels,3) 8 'B' 9 >>>
測試結果:[0,0]屬於分類B.
3、如何測試分類器
四、 示例:使用k-近鄰算法改進約會網站的配對效果
我的朋友海倫一直使用在線約會網站尋找適合自己的約會對象。盡管約會網站會推薦不同的人選,但她並不是喜歡每一個人。經過一番總結,她發現曾交往過三種類型的人:
- 不喜歡的人
- 魅力一般的人
- 極具魅力的人
海倫希望我們的分類軟件可以更好地幫助她將匹配對象划分到確切的分類中。此外海倫還收集了一些約會網站未曾記錄的數據信息,她認為這些數據更有助於匹配對象的歸類。
1、准備數據:從文本文件中解析數據
數據存放在文本文件datingTestSet.txt中,每個樣本數據占據一行,總共有1000行。
海倫的樣本主要包含以下3種特征:
- 每年獲得的飛行常客里程數
- 玩視頻游戲所耗時間百分比
- 每周消費的冰淇淋公升數
2、分析數據:使用Matplotlib創建散點圖
散點圖使用datingDataMat矩陣的第一、第二列數據,分別表示特征值“每年獲得的飛行常客里程數”和“玩視頻游戲所耗時間百分比”。
每年贏得的飛行常客里程數與玩視頻游戲所占百分比的約會數據散點圖
3、准備數據:歸一化數值
不同特征值有不同的均值和取值范圍,如果直接使用特征值計算距離,取值范圍較大的特征將對距離計算的結果產生絕對得影響,而使較小的特征值幾乎沒有作用,近乎沒有用到該屬性。如兩組特征:{0, 20000, 1.1}和{67, 32000, 0.1},計算距離的算式為:
顯然第二個特征將對結果產生絕對得影響,第一個特征和第三個特征幾乎不起作用。
然而,對於識別的過程,我們認為這不同特征是同等重要的,因此作為三個等權重的特征之一,飛行常客里程數並不應該如此嚴重地影響到計算結果。
在處理這種不同取值范圍的特征值時,我們通常采用的方法是將數值歸一化,如將取值范圍處理為0到1或者1到1之間。下面的公式可以將任意取值范圍的特征值轉化為0到1區間內的值:
newValue = (oldValue – min) / (max – min)
其中min和max分別是數據集中的最小特征值和最大特征值。
添加autoNorm()函數,用於將數字特征值歸一化:
1 def autoNorm(dataSet): 2 minVals = dataSet.min(0)# 分別求各個特征的最小值 3 maxVals = dataSet.max(0)# 分別求各個特征的最大值 4 ranges = maxVals - minVals# 各個特征的取值范圍 5 normDataSet = zeros(shape(dataSet)) 6 m = dataSet.shape[0] 7 normDataSet = dataSet - tile(minVals, (m,1)) # oldValue - min 8 normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide (oldValue-min)/(max-min) 數據歸一化處理 9 return normDataSet, ranges, minVals
對這個函數,要注意返回結果除了歸一化好的數據,還包括用來歸一化的范圍值ranges和最小值minVals,這將用於對測試數據的歸一化。
注意,對測試數據集的歸一化過程必須使用和訓練數據集相同的參數(ranges和minVals),不能針對測試數據單獨計算ranges和minVals,否則將造成同一組數據在訓練數據集和測試數據集中的不一致。
4、測試算法:作為完整程序驗證分類器
機器學習算法一個很重要的工作就是評估算法的正確率,通常我們只提供已有數據的90%作為訓練樣本來訓練分類器,而使用其余的10%數據去測試分類器,檢測分類器的正確率。需要注意的是,10%的測試數據應該是隨機選擇的。由於海倫提供的數據並沒有按照特定目的來排序,所以我們可以隨意選擇10%數據而不影響其隨機性。
創建分類器針對約會網站的測試代碼:利用樣本集數據進行測試算法
1 def datingClassTest(): 2 hoRatio = 0.50 #hold out 10% 3 datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file 4 normMat, ranges, minVals = autoNorm(datingDataMat) 5 m = normMat.shape[0] 6 numTestVecs = int(m*hoRatio) 7 errorCount = 0.0 8 for i in range(numTestVecs): 9 classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) 10 print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]) 11 if (classifierResult != datingLabels[i]): errorCount += 1.0 12 print "the total error rate is: %f" % (errorCount/float(numTestVecs)) 13 print errorCount
執行分類器測試程序:
1 >>> kNN.datingClassTest() 2 3 the classifier came back with: 2, the real answer is: 1 4 5 the classifier came back with: 2, the real answer is: 2 6 7 the classifier came back with: 1, the real answer is: 1 8 9 the classifier came back with: 1, the real answer is: 1 10 11 the classifier came back with: 2, the real answer is: 2 12 13 ................................................. 14 15 the total error rate is: 0.064000 16 17 32.0
分類器處理約會數據集的錯誤率是6.4%,這是一個相當不錯的結果。我們可以改變函數datingClassTest內變量hoRatio和變量k的值,檢測錯誤率是否隨着變量值的變化而增加。
這個例子表明我們可以正確地預測分類,錯誤率僅僅是2.4%。海倫完全可以輸入未知對象的屬性信息,由分類軟件來幫助她判定某一對象的可交往程度:討厭、一般喜歡、非常喜歡。
5、使用算法:構建完整可用系統
綜合上述代碼,我們可以構建完整的約會網站預測函數:對輸入的數據需要 歸一化處理
1 def classifyPerson(): 2 resultList = ['not at all', 'in small doses', 'in large doses'] 3 percentTats = float(raw_input("Percentage of time spent playing video game?")) 4 ffMiles = float(raw_input("Frequent flier miles earned per year?")) 5 iceCream = float(raw_input("Liters of ice cream consumed per year?")) 6 datingDataMat, datingLabels = file2matrix('datingTestSet.txt') 7 normMat, ranges, minVals = autoNorm(datingDataMat) 8 inArr = array([ffMiles, percentTats, iceCream]) #新數據 需要歸一化處理 9 classifierResult = classify((inArr - minVals) / ranges, normMat, datingLabels, 3) 10 print "You will probably like this person: ", resultList[classifierResult - 1]
目前為止,我們已經看到如何在數據上構建分類器。
完整代碼:

1 ''' 2 Created on Sep 16, 2010 3 kNN: k Nearest Neighbors 4 5 Input: inX: vector to compare to existing dataset (1xN) 6 dataSet: size m data set of known vectors (NxM) 7 labels: data set labels (1xM vector) 8 k: number of neighbors to use for comparison (should be an odd number) 9 10 Output: the most popular class label 11 12 @author: pbharrin 13 ''' 14 from numpy import * 15 import operator 16 from os import listdir 17 import matplotlib 18 import matplotlib.pyplot as plt 19 def show(d,l): 20 #d,l=kNN.file2matrix('datingTestSet2.txt') 21 fig=plt.figure() 22 ax=fig.add_subplot(111) 23 ax.scatter(d[:,0],d[:,1],15*array(l),15*array(l)) 24 plt.show() 25 def show2(): 26 datingDataMat,datingLabels=file2matrix('datingTestSet2.txt') 27 fig = plt.figure() 28 ax = fig.add_subplot(111) 29 l=datingDataMat.shape[0] 30 X1=[] 31 Y1=[] 32 X2=[] 33 Y2=[] 34 X3=[] 35 Y3=[] 36 for i in range(l): 37 if datingLabels[i]==1: 38 X1.append(datingDataMat[i,0]);Y1.append(datingDataMat[i,1]) 39 elif datingLabels[i]==2: 40 X2.append(datingDataMat[i,0]);Y2.append(datingDataMat[i,1]) 41 else: 42 X3.append(datingDataMat[i,0]);Y3.append(datingDataMat[i,1]) 43 type1=ax.scatter(X1,Y1,c='red') 44 type2=ax.scatter(X2,Y2,c='green') 45 type3=ax.scatter(X3,Y3,c='blue') 46 #ax.axis([-2,25,-0.2,2.0]) 47 ax.legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2) 48 plt.xlabel('Percentage of Time Spent Playing Video Games') 49 plt.ylabel('Liters of Ice Cream Consumed Per Week') 50 plt.show() 51 52 def classify0(inX, dataSet, labels, k): 53 dataSetSize = dataSet.shape[0] 54 diffMat = tile(inX, (dataSetSize,1)) - dataSet 55 sqDiffMat = diffMat**2 56 sqDistances = sqDiffMat.sum(axis=1) 57 distances = sqDistances**0.5 58 sortedDistIndicies = distances.argsort() 59 classCount={} 60 for i in range(k): 61 voteIlabel = labels[sortedDistIndicies[i]] 62 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 63 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 64 return sortedClassCount[0][0] 65 66 def createDataSet(): 67 group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) 68 labels = ['A','A','B','B'] 69 return group, labels 70 71 def file2matrix(filename): 72 fr = open(filename) 73 numberOfLines = len(fr.readlines()) #get the number of lines in the file 74 returnMat = zeros((numberOfLines,3)) #prepare matrix to return 75 classLabelVector = [] #prepare labels return 76 fr = open(filename) 77 index = 0 78 for line in fr.readlines(): 79 line = line.strip() 80 listFromLine = line.split('\t') 81 returnMat[index,:] = listFromLine[0:3] 82 classLabelVector.append(int(listFromLine[-1])) 83 index += 1 84 return returnMat,classLabelVector 85 86 def autoNorm(dataSet): 87 minVals = dataSet.min(0) 88 maxVals = dataSet.max(0) 89 ranges = maxVals - minVals 90 normDataSet = zeros(shape(dataSet)) 91 m = dataSet.shape[0] 92 normDataSet = dataSet - tile(minVals, (m,1)) 93 normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide 94 return normDataSet, ranges, minVals 95 96 def datingClassTest(): 97 hoRatio = 0.50 #hold out 10% 98 datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file 99 normMat, ranges, minVals = autoNorm(datingDataMat) 100 m = normMat.shape[0] 101 numTestVecs = int(m*hoRatio) 102 errorCount = 0.0 103 for i in range(numTestVecs): 104 classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) 105 print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]) 106 if (classifierResult != datingLabels[i]): errorCount += 1.0 107 print "the total error rate is: %f" % (errorCount/float(numTestVecs)) 108 print errorCount 109 110 def img2vector(filename): 111 returnVect = zeros((1,1024)) 112 fr = open(filename) 113 for i in range(32): 114 lineStr = fr.readline() 115 for j in range(32): 116 returnVect[0,32*i+j] = int(lineStr[j]) 117 return returnVect 118 119 def handwritingClassTest(): 120 hwLabels = [] 121 trainingFileList = listdir('trainingDigits') #load the training set 122 m = len(trainingFileList) 123 trainingMat = zeros((m,1024)) 124 for i in range(m): 125 fileNameStr = trainingFileList[i] 126 fileStr = fileNameStr.split('.')[0] #take off .txt 127 classNumStr = int(fileStr.split('_')[0]) 128 hwLabels.append(classNumStr) 129 trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) 130 testFileList = listdir('testDigits') #iterate through the test set 131 errorCount = 0.0 132 mTest = len(testFileList) 133 for i in range(mTest): 134 fileNameStr = testFileList[i] 135 fileStr = fileNameStr.split('.')[0] #take off .txt 136 classNumStr = int(fileStr.split('_')[0]) 137 vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) 138 classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) 139 print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) 140 if (classifierResult != classNumStr): errorCount += 1.0 141 print "\nthe total number of errors is: %d" % errorCount 142 print "\nthe total error rate is: %f" % (errorCount/float(mTest))