使用KNN對MNIST數據集進行實驗


由於KNN的計算量太大,還沒有使用KD-tree進行優化,所以對於60000訓練集,10000測試集的數據計算比較慢。這里只是想測試觀察一下KNN的效果而已,不調參。

K選擇之前看過貌似最好不要超過20,因此,此處選擇了K=10,距離為歐式距離。如果需要改進,可以再調整K來選擇最好的成績。

先跑了一遍不經過scale的,也就是直接使用像素灰度值來計算歐式距離進行比較。發現開始基本穩定在95%的正確率上,嚇了一跳。因為本來覺得KNN算是沒有怎么“學習”的機器學習算法了,猜測它的特點可能會是在任何情況下都可以用,但都表現的不是最好。所以估計在60%~80%都可以接受。沒想到能基本穩定在95%上,確定算法和代碼沒什么問題后,突然覺得是不是這個數據集比較沒挑戰性。。。

去MNIST官網(http://yann.lecun.com/exdb/mnist/),上面掛了以該數據集為數據的算法的結果比較。查看了一下KNN,發現有好多,而且錯誤率基本都在5%以內,甚至能做到1%以內。唔。

跑的結果是,正確率:96.687%。也就是說,錯誤率error rate為3.31%左右。

再跑一下經過scale的數據,即對灰度數據歸一化到[0,1]范圍內。看看效果是否有所提升。

經過scale,最終跑的結果是,正確率:竟然也是96.687%! 也就是說,對於該數據集下,對KNN的數據是否進行歸一化並無效果!

在跑scale之前,個人猜測:由於一般對數據進行處理之前都進行歸一化,防止高維詛咒(在784維空間中很容易受到高維詛咒)。因此,預測scale后會比前者要好一些的。但是,現在看來二者結果相同。也就是說,對於K=10的KNN算法中,對MNIST的預測一樣的。

對scale前后的正確率相同的猜測:由於在訓練集合中有60000個數據點,因此0-9每個分類平均都有6000個數據點,在這樣的情況下,對於測試數據集中的數據點,相臨近的10個點中大部分都是其他分類而導致分類錯誤的概率會比較地(畢竟10相對與6000來說很小),所以,此時,KNN不僅可以取得較好的分類效果,而且對於是否scale並不敏感,效果相同。

代碼如下:

  1. #KNN for MNIST  
  2. from numpy import *  
  3. import operator  
  4.   
  5. def line2Mat(line):  
  6.     line = line.strip().split(' ')  
  7.     label = line[0]  
  8.     mat = []  
  9.     for pixel in line[1:]:  
  10.         pixel = pixel.split(':')[1]  
  11.         mat.append(float(pixel))  
  12.     return mat, label  
  13.   
  14. #matrix should be type: array. Or classify() will get error.  
  15. def file2Mat(fileName):  
  16.     f = open(fileName)  
  17.     lines = f.readlines()  
  18.     matrix = []  
  19.     labels = []  
  20.     for line in lines:  
  21.         mat, label = line2Mat(line)  
  22.         matrix.append(mat)  
  23.         labels.append(label)  
  24.     print 'Read file '+str(fileName) + ' to matrix done!'  
  25.     return array(matrix), labels  
  26.   
  27. #classify mat with trained data: matrix and labels. With KNN's K set.  
  28. def classify(mat, matrix, labels, k):  
  29.     diffMat = tile(mat, (shape(matrix)[0], 1)) - matrix  
  30.     #diffMat = array(diffMat)  
  31.     sqDiffMat = diffMat ** 2  
  32.     sqDistances = sqDiffMat.sum(axis=1)  
  33.     distances = sqDistances ** 0.5  
  34.     sortedDistanceIndex = distances.argsort()  
  35.     classCount = {}  
  36.     for i in range(k):  
  37.         voteLabel = labels[sortedDistanceIndex[i]]  
  38.         classCount[voteLabel] = classCount.get(voteLabel,0) + 1  
  39.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True)  
  40.     return sortedClassCount[0][0]  
  41.       
  42. def classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K):  
  43.     rightCnt = 0  
  44.     for i in range(len(testMatrix)):  
  45.         if i % 100 == 0:  
  46.             print 'num '+str(i)+'. ratio: '+ str(float(rightCnt)/(i+1))  
  47.         label = testLabels[i]  
  48.         predictLabel = classify(testMatrix[i], trainMatrix, trainLabels, K)  
  49.         if label == predictLabel:  
  50.             rightCnt += 1  
  51.     return float(rightCnt)/len(testMatrix)  
  52.   
  53. trainFile = 'train_60k.txt'  
  54. testFile = 'test_10k.txt'  
  55. trainMatrix, trainLabels = file2Mat(trainFile)  
  56. testMatrix, testLabels = file2Mat(testFile)  
  57. K = 10  
  58. rightRatio = classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K)  
  59. print 'classify right ratio:' +str(right)  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM