【機器學習】手寫數字識別算法


1.數據准備

樣本數據獲取忽略,實際上就是將32*32的圖片上數字格式化成一個向量,如下:

 

本demo所有樣本數據都是基於這種格式的

訓練數據:將圖片數據轉成1*1024的數組,作為一個訓練數據。

訓練數據集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/trainingDigits

測試數據集:https://github.com/zimuqi/machine_Learning/tree/master/ch02/testDigits

樣本的文件名格式為:真實值_xxx.txt

轉換代碼:

1 def img2vector(filename):
2     returnVect=zeros((1,1024))
3     fr=open(filename)
4     for i in range(32):
5         lineStr=fr.readline()
6         for j in range(32):
7             returnVect[0,32*i+j]=int(lineStr[j])
8     return returnVect

 

2.測試算法

 1 def handwritingClassTest():
 2     hwLabels=[]    # 訓練樣本的標簽數組
 3     traningFileList=listdir("trainingDigits")    # 獲取所有的訓練樣本目錄下的文件名
 4     m=len(traningFileList)
 5     traningMat=zeros((m,1024))    # 初始化訓練樣本數列
 6 
 7     for i in range(m):
 8         fileNameStr=traningFileList[i]    # 獲取文件名
 9         fileStr=fileNameStr.split(".")[0]   
10         clasNumStr=int(fileStr.split("_")[0])    # 獲取樣本的實際值 放入標簽數組
11         hwLabels.append(clasNumStr)
12         traningMat[i,:]=img2vector("trainingDigits/{}".format(fileNameStr))    # 將樣本轉化成1*1024的行放入訓練樣本數列
13 
14     testFileList=listdir("testDigits")    # 測試樣本目錄
15     error=0
16     mtest=len(testFileList)
17     for i in range(mtest):
18         fileNameStr=testFileList[i]
19         fileStr=fileNameStr.split(".")[0]
20         clasNumStr=int(fileStr.split("_")[0])
21         testMat=img2vector("testDigits/{}".format(fileNameStr))
22         res=classify(testMat,traningMat,hwLabels,3)     # 使用分類器分類
23         print "came bank with:{} the real anwser is:{}".format(clasNumStr,res)
24         if clasNumStr!=res:    # 對比與真實的結果 計算錯誤率
25             error+=1
26 
27     print "total:{}".format(mtest)
28     print "error:{}".format(error)
29     print "error:{}".format(float(error/mtest))

這個案例中 算法的識別率為:98.84%

classify是分類器 上上一篇文章中有寫到,具體了解可以點擊這里

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM