一、 K鄰近算法思想:存在一個樣本數據集合,稱為訓練樣本集,並且每個數據都存在標簽,即我們知道樣本集中每一數據(這里的數據是一組數據,可以是n維向量)與所屬分類的對應關系。輸入沒有標簽的新數據后,將新數據的每個特征(向量的每個元素)與樣本集中數據對應的特征進行比較,然后算法提取樣本集中特征最相似的的分類標簽。由於樣本集可以很大,我們選取前k個最相似數據,然后統計k個數據中出現頻率最高的標簽為新數據的標簽。
K鄰近算法的一般流程:
(1)收集數據:可以是本地數據,也可以從網頁抓取。
(2)准備數據:將數據結構化,方便操作。
(3)分析數據:可以使用任何方法。
(4)訓練算法:此步驟不適用於k鄰近算法。
(5)測試算法:計算錯誤率;計算公式:錯誤率=測試出錯次數/總測試次數
(6)使用算法:輸入樣本數據,輸出結構化的結果,判斷新數據屬於哪個分類。
二、使用K近鄰算法的一個例子
我使用的是spyder的開發環境,python的版本是3.5,spyder自帶了numpy函數庫。新建一個KNN.py文件,在本文件中完成本章實驗。
在KNN中寫一個數據生成函數:
1 from numpy import * 2 import operator 3 4 def createDataset(): 5 group = array([[1.0,1.1],[1.0,1.0],[0.0,0.0],[0.0,0.1]]) 6 labels = ['A','A','B','B'] 7 return group,labels
在spyder中輸入 :
>>> import KNN
>>>group,labels = KNN.createDataSet()
>>>group
array([[ 1. , 1.1],
[ 1. , 1. ],
[ 0. , 0. ],
[ 0. , 0.1]])
>>>labels
['A', 'A', 'B', 'B']
出現以上提示則說明函數正確。
三、K近鄰算法函數
1 def classify(inX,dataset,labels,k): 2 dataSetSize = dataset.shape[0] 3 diffMat = tile(inX,(dataSetSize,1))-dataset 4 sqDiffMat = diffMat**2 5 sqDistances = sqDiffMat.sum(axis=1) 6 distances = sqDistances**0.5 7 sortedDistIndicies = distances.argsort() 8 classCount ={} 9 for i in range(k): 10 voteIlabel = labels[sortedDistIndicies[i]] 11 classCount[voteIlabel] = classCount.get(voteIlabel,0)+1 12 sortedClassCount = sorted(classCount.items(), 13 key=operator.itemgetter(1),reverse=True) 14 return sortedClassCount[0][0]
驗證:在spyder中輸入
>>> KNN.classify([0,0],group,labels,3)
輸出結果應該為'B'。
四、例子:約會網站匹配改進
海倫收集約會數據已經有一段時間,她把這些數據放在文本文件datingdata.txt中,每個樣本數據占據一行,共有1000行(她可能約會過1000個人,太可怕了^_^),每個樣本主要包括以下3中特征:
1、每年獲得的飛行常客里程數
2、玩視頻游戲所耗的時間百分數
3、每周消費的冰激凌公升數
上述數據保存在文本文件中,數據之間以空格間隔,在數據輸入分類器之前,必須將待處理數據改變為分類器可以處理的數據,在KNN中創建名為file2matrix的函數,進行數據處理。
1 def file2matrix(filename): 2 fr = open(filename,'r') 3 arrayOLines = fr.readlines() 4 numberOfLines = len(arrayOLines) 5 returnMat = zeros((numberOfLines,3)) 6 classLabelVector = [] 7 index = 0 8 for line in arrayOLines: 9 line = line.strip() 10 listFromLine = line.split('\t') 11 returnMat[index,:] = listFromLine[0:3] 12 classLabelVector.append(int(listFromLine[-1])) 13 index += 1 14 return returnMat,classLabelVector 15 retarnmat,classlabelvector = file2matrix('datingdata.txt')
在我運行這段程序,總是出現錯誤提示:could not convert string to float: '12 34 56',對於這個問題,我的改法是將文本中數據間的空格改為','並將
listFromLine = line.split('\t')改為
listFromLine = line.split(',')
這樣就可以解決問題,但是不是最好的方法,還需要改進。