KNN分類算法實現手寫數字識別


需求:

利用一個手寫數字“先驗數據”集,使用knn算法來實現對手寫數字的自動識別;

先驗數據(訓練數據)集:

♦數據維度比較大,樣本數比較多。

♦ 數據集包括數字0-9的手寫體。

♦每個數字大約有200個樣本。

♦每個樣本保持在一個txt文件中。

♦手寫體圖像本身的大小是32x32的二值圖,轉換到txt文件保存后,內容也是32x32個數字,0或者1,如下:

數據集壓縮包解壓后有兩個目錄:(將這兩個目錄文件夾拷貝的項目路徑下E:/KNNCase/digits/

♦目錄trainingDigits存放的是大約2000個訓練數據

♦目錄testDigits存放大約900個測試數據。

 

模型分析:

1、手寫體因為每個人,甚至每次寫的字都不會完全精確一致,所以,識別手寫體的關鍵是“相似度”

2、既然是要求樣本之間的相似度,那么,首先需要將樣本進行抽象,將每個樣本變成一系列特征數據(即特征向量)

3、手寫體在直觀上就是一個個的圖片,而圖片是由上述圖示中的像素點來描述的,樣本的相似度其實就是像素的位置和顏色之間的組合的相似度

4、因此,將圖片的像素按照固定順序讀取到一個個的向量中,即可很好地表示手寫體樣本

5、抽象出了樣本向量,及相似度計算模型,即可應用KNN來實現

 

python實現:

新建一個KNN.py腳本文件,文件里面包含四個函數:

1) 一個用來生成將每個樣本的txt文件轉換為對應的一個向量,

2) 一個用來加載整個數據集,

3) 一個實現kNN分類算法。

4) 最后就是實現加載、測試的函數。

 
  1 #!/usr/bin/python
  2 # coding=utf-8
  3 #########################################
  4 # kNN: k Nearest Neighbors
  5 
  6 # 參數:        inX: vector to compare to existing dataset (1xN)
  7 #             dataSet: size m data set of known vectors (NxM)
  8 #             labels: data set labels (1xM vector)
  9 #             k: number of neighbors to use for comparison
 10 
 11 # 輸出:     多數類
 12 #########################################
 13 
 14 from numpy import *
 15 import operator
 16 import os
 17 
 18 
 19 # KNN分類核心方法
 20 def kNNClassify(newInput, dataSet, labels, k):
 21     numSamples = dataSet.shape[0]  # shape[0]代表行數
 22 
 23     # # step 1: 計算歐式距離
 24     # tile(A, reps): 將A重復reps次來構造一個矩陣
 25     # the following copy numSamples rows for dataSet
 26     diff = tile(newInput, (numSamples, 1)) - dataSet  # Subtract element-wise
 27     squaredDiff = diff ** 2  # squared for the subtract
 28     squaredDist = sum(squaredDiff, axis = 1)   # sum is performed by row
 29     distance = squaredDist ** 0.5
 30 
 31     # # step 2: 對距離排序
 32     # argsort()返回排序后的索引
 33     sortedDistIndices = argsort(distance)
 34 
 35     classCount = {}  # 定義一個空的字典
 36     for i in xrange(k):
 37         # # step 3: 選擇k個最小距離
 38         voteLabel = labels[sortedDistIndices[i]]
 39 
 40         # # step 4: 計算類別的出現次數
 41         # when the key voteLabel is not in dictionary classCount, get()
 42         # will return 0
 43         classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
 44 
 45     # # step 5: 返回出現次數最多的類別作為分類結果
 46     maxCount = 0
 47     for key, value in classCount.items():
 48         if value > maxCount:
 49             maxCount = value
 50             maxIndex = key
 51 
 52     return maxIndex
 53 
 54 # 將圖片轉換為向量
 55 def  img2vector(filename):
 56     rows = 32
 57     cols = 32
 58     imgVector = zeros((1, rows * cols))
 59     fileIn = open(filename)
 60     for row in xrange(rows):
 61         lineStr = fileIn.readline()
 62         for col in xrange(cols):
 63             imgVector[0, row * 32 + col] = int(lineStr[col])
 64 
 65     return imgVector
 66 
 67 # 加載數據集
 68 def loadDataSet():
 69     # # step 1: 讀取訓練數據集
 70     print "---Getting training set..."
 71     dataSetDir = 'E:/KNNCase/digits/'
 72     trainingFileList = os.listdir(dataSetDir + 'trainingDigits')  # 加載測試數據
 73     numSamples = len(trainingFileList)
 74 
 75     train_x = zeros((numSamples, 1024))
 76     train_y = []
 77     for i in xrange(numSamples):
 78         filename = trainingFileList[i]
 79 
 80         # get train_x
 81         train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename)
 82 
 83         # get label from file name such as "1_18.txt"
 84         label = int(filename.split('_')[0]) # return 1
 85         train_y.append(label)
 86 
 87     # # step 2:讀取測試數據集
 88     print "---Getting testing set..."
 89     testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set
 90     numSamples = len(testingFileList)
 91     test_x = zeros((numSamples, 1024))
 92     test_y = []
 93     for i in xrange(numSamples):
 94         filename = testingFileList[i]
 95 
 96         # get train_x
 97         test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename)
 98 
 99         # get label from file name such as "1_18.txt"
100         label = int(filename.split('_')[0]) # return 1
101         test_y.append(label)
102 
103     return train_x, train_y, test_x, test_y
104 
105 # 手寫識別主流程
106 def testHandWritingClass():
107     # # step 1: 加載數據
108     print "step 1: load data..."
109     train_x, train_y, test_x, test_y = loadDataSet()
110 
111     # # step 2: 模型訓練.
112     print "step 2: training..."
113     pass
114 
115     # # step 3: 測試
116     print "step 3: testing..."
117     numTestSamples = test_x.shape[0]
118     matchCount = 0
119     for i in xrange(numTestSamples):
120         predict = kNNClassify(test_x[i], train_x, train_y, 3)
121         if predict == test_y[i]:
122             matchCount += 1
123     accuracy = float(matchCount) / numTestSamples
124 
125     # # step 4: 輸出結果
126     print "step 4: show the result..."
127     print 'The classify accuracy is: %.2f%%' % (accuracy * 100)
 

 

KNNTest.py

#!/usr/bin/python
# coding=utf-8

import KNN
KNN.testHandWritingClass()

 

測試結果:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM