k最鄰近算法——使用kNN進行手寫識別


上篇文章中提到了使用pillow對手寫文字進行預處理,本文介紹如何使用kNN算法對文字進行識別。

基本概念

  k最鄰近算法(k-Nearest Neighbor, KNN),是機器學習分類算法中最簡單的一類。假設一個樣本空間被分為幾類,然后給定一個待分類的特征數據,通過計算距離該數據的最近的k個樣本來判斷這個數據屬於哪一類。如果距離待分類屬性最近的k個類大多數都屬於某一個特定的類,那么這個待分類的數據也就屬於這個類。所謂K最近鄰,就是k個最近的鄰居的意思,說的是每個樣本都可以用它最接近的k個鄰居來代表。kNN在確定分類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別,在決策時,只與極少量的相鄰樣本有關。通常,k是不大於20的整數。

  下圖所示,綠色圓要被決定賦予哪個類,是紅色三角形還是藍色四方形?如果K=3,由於紅色三角形所占比例為2/3,綠色圓將被賦予紅色三角形那個類,如果K=5,由於藍色四方形比例為3/5,因此綠色圓被賦予藍色四方形類。

  在理想情況下,k值選擇1,即只選擇最近的鄰居。在現實生活中往往沒這么理想,比如對於價格來說,有些顧客消息閉塞,可能會為 “最近的鄰居”多付很多錢,所以應當貨比三家,多選擇一些鄰居,取均值來減少噪聲。實際上,k值過大或過小都將影響結果。

計算過程

過程如下:

  1. 計算訓練集中的點與當前點之間的距離;
  2. 按距離降序排序;
  3. 選取與當前點距離最小的k個點;
  4. 如果是數值型數據,計算前k個點的均值;如果是離散數據,計算前k個點所在類別出現的頻率;
  5. 如果是數值型數據,返回前k個點的均值作為預測數值;如果是離散數據,返回前k個點出現頻率最高的類別作為預測分類。

  代碼如下:

 1 from os import  listdir
 2 
 3 #將圖片文件轉換為向量
 4 def img2vector(filename):
 5     with open(filename) as fobj:
 6         arr = fobj.readlines()
 7 
 8     vec, demension = [], len(arr)
 9     for i in range(demension):
10         line = arr[i].strip()
11         for j in range(demension):
12             vec.append(int(line[j]))
13 
14     return vec
15 
16 #讀取訓練數據
17 def createDataset(dir):
18     dataset, labels = [], []
19     files = listdir(dir)
20     for filename in files:
21         label = int(filename[0])
22         labels.append(label)
23         dataset.append(img2vector(dir + '/' + filename))
24 
25     return dataset, labels
26 
27 #計算谷本系數
28 def tanimoto(vec1, vec2):
29     c1, c2, c3 = 0, 0, 0
30     for i in range(len(vec1)):
31         if vec1[i] == 1: c1 += 1
32         if vec2[i] == 1: c2 += 1
33         if vec1[i] == 1 and vec2[i] == 1: c3 += 1
34 
35     return c3 / (c1 + c2 - c3)
36 
37 def classify(dataset, labels, testData, k=20):
38     distances = []
39 
40     for i in range(len(labels)):
41         d = tanimoto(dataset[i], testData)
42         distances.append((d, labels[i]))
43 
44     distances.sort(reverse=True)
45     #key  label,  value   count of the label
46     klabelDict = {}
47     for i in range(k):
48         klabelDict.setdefault(distances[i][1], 0)
49         klabelDict[distances[i][1]] += 1 / k
50 
51     #按value降序排序
52     predDict = sorted(klabelDict.items(), key=lambda item: item[1], reverse=True)
53     return predDic
54 
55 dataset, labels = createDataset('trainingDigits')
56 testData = img2vector('testDigits/8_19.txt')
57 print(classify(dataset, labels, testData))

  我們事先使用pillow對手寫數字進行了二值化處理,形成一個32*32的矩陣,並將每個訓練樣本保存到一個txt文件,文件名以數字開頭,這個數字就是手寫數字的label,如3_1.txt,其中的內容是:

        由於特征值僅由0和1構成,可以將二維的樣本數據保存到一維數組,img2vector完成了數據的轉換。在計算相似度時,使用谷本系數(Tanimoto)計算有限離散集之間的距離,其公式是:兩者重合(相交)的越多,其相似度越高。classify對測試數據進行分類,返回一個包含了預測結果和結果幾率的字典。

 

 加權kNN

  在上述手寫識別的例子中,供使用了900個測試樣本,其中34個產生了誤判,下圖是一個誤判的例子:

  圖中是手寫數字1,程序判斷為7,其原因是代碼所用的方法有可能會選擇很遠的近鄰:

  Y點肉眼去看因為在紅色區域內很容易判斷出多半屬於紅色一類,但因為藍色過多,若K值選取稍大則很容易將其歸為藍色一類。為了改進這一點,可以為每個點的距離增加一個權重,這樣距離近的點可以得到更大的權重。具體的加權方法將會在下一篇文章中介紹。

 


   出處:微信公眾號 "我是8位的"

   本文以學習、研究和分享為主,如需轉載,請聯系本人,標明作者和出處,非商業用途! 

   掃描二維碼關注作者公眾號“我是8位的”


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM