K-鄰近分類算法——分類MNIST手寫體數據算法(機器學習實戰)


  k 近鄰法(K-nearest neighbor, KNN)是一種基本分類於回歸方法,其在1968年由Cover和Hart提出的。k 近鄰算法采用測量不同特征值之間的距離方法進行分類。其輸入為示例的特征向量,對應於特征空間的點;輸出為實例的類別,可以取多類。

  k 近鄰法假設給定一個訓練數據集,其中的上實例類別已定,分類時,對新的實例,根據其K 個最近鄰的訓練實力的類別,通過多數表決等方式進行預測。k 近鄰法實際上利用訓練數據集對特征向量哦那關鍵進行划分,並作為其分類的“模型”。

  k 鄰近法的基本三要素為: k 值的選擇、距離度量以及分類決策規則。

實例:k-近鄰算法的手寫識別系統

1.收集數據:提供文本文件。

2.准備數據:編寫函數classify0() ,將圖像格式轉換為分類器使用的list格式。

3.分析數據:在Python命令提示符中檢查數據,確保它符合要求。

4.測試算法:編寫函數使用提供的部分數據集作為測試樣本,測試樣本與非測試樣本的區別在於測試樣本是已經完成分類的數據,如果預測分類與實際類別不同,則標記為一個錯誤。

步驟:

1、該手寫體數據集合修改自"手寫數字數據集的光學識別"一文中的數據集合,該文登載於2010年10月3日的UCI機器學習資料庫中 http://archive.ics.uci.edu/ml。

2、准備數據

2.1 編寫一段函數img2vector ,將圖像轉換為向量:該函數創建1x1024的NumPy數組,然后打開給定的文件,循環讀出文件的前32行,並將每行的頭32個字符值存儲在NumPy數組中,最后返回數組。

def img2vector(filename):
    returnVect = zeros((1,1024)) #創建1*1024的數組
    fr = open(filename) #打開文件
    #循環讀出文件的前32行,並將每行的頭32個字符值存儲在NumPy數組中
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect #返回數組

2.2 准備分類:k 鄰近算法,此處距離度量采用歐式距離度量;classify0() 函數有4個輸入參數:用於分類的輸入向量是inX,輸入的訓練樣本集為dataSet,標簽向量為labels ,最后的參數k 表示用於選擇最近鄰居的數目,其中標簽向量的元素數目和矩陣dataSet 的行數相同。

def classify0(inX, dataSet, labels, k):

    dataSetSize = dataSet.shape[0]

    #❶(以下三行)距離計算(歐氏距離)

    diffMat = tile(inX, (dataSetSize,1)) - dataSet

    sqDiffMat = diffMat**2

    sqDistances = sqDiffMat.sum(axis=1)

    distances = sqDistances**0.5

    sortedDistIndicies = distances.argsort()     

    classCount={}   

    #❷ (以下兩行)選擇距離最小的k個點     

    for i in range(k):

         voteIlabel = labels[sortedDistIndicies[i]]

         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1

    sortedClassCount = sorted(classCount.iteritems(), 

      #❸ 排序

      key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]

距離度量:歐式距離、曼哈頓距離、切比雪夫距離;這三種分別為閔可夫斯基度量的2范、1范、∞。將上述classify0()中內容關於距離距離內容修改

3、使用k 鄰近識別手寫體數據

函數handwritingClassTest() 是測試分類器的代碼,將其寫入kNN.py文件中。在寫入這些代碼之前,我們必須確保將from os import listdir 寫入文件的起始部分,這段代碼的主要功能是從os模塊中導入函數listdir ,它可以列出給定目錄的文件名。

def handwritingClassTest():

    hwLabels = []

    trainingFileList = listdir('trainingDigits')           #❶ 獲取目錄內容

    m = len(trainingFileList)

    trainingMat = zeros((m,1024))

    for i in range(m):

          #❷ (以下三行)從文件名解析分類數字

        fileNameStr = trainingFileList[i]

        fileStr = fileNameStr.split('.')[0]               

        classNumStr = int(fileStr.split('_')[0])

        hwLabels.append(classNumStr)

        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)

    testFileList = listdir('testDigits')        

    errorCount = 0.0

    mTest = len(testFileList)

    for i in range(mTest):

        fileNameStr = testFileList[i]

        fileStr = fileNameStr.split('.')[0]     

        classNumStr = int(fileStr.split('_')[0])

        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)

        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

        print "the classifier came back with: %d, the real answer is: %d"\% (classifierResult, classNumStr)

        if (classifierResult != classNumStr): errorCount += 1.0

    print "\nthe total number of errors is: %d" % errorCount

    print "\nthe total error rate is: %f" % (errorCount/float(mTest))

4、實驗比較不同K值影響,距離度量等影響

1 不同K值以及距離度量的錯誤率

距離度量

K=3(error rate)

K=5(error rate)

8(error rate)

歐式距離

0.010571

0.017970

0.019027

曼哈頓距離

0.010571

0.017970

0.019027

切比雪夫

0.908034

0.908034

0.908034

整體算法以及部分實驗結果如下所示:

from numpy import *
import operator
from os import listdir

def createDataSet():
    groups = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return groups,labels

def classify0(inX, dataSet, labels, k ):
    dataSetSize = dataSet.shape[0]
    #計算距離
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis = 1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    #選擇距離最近的
    for i in range(k):
        votelabel = labels[sortedDistIndicies[i]]
        classCount[votelabel] = classCount.get(votelabel,0) + 1
    sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)
    #排序
    return sortedClassCount[0][0]

#圖像轉向量
def img2vector(filename):
    returnVect = zeros((1,1024)) #創建1*1024的數組
    fr = open(filename) #打開文件
    #循環讀出文件的前32行,並將每行的頭32個字符值存儲在NumPy數組中
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    return returnVect #返回數組

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('digits/trainingDigits')#獲取目錄內容;trainingDigits目錄中的文件內容存儲在列表中
    m = len(trainingFileList)#得到目錄中有多少文件,並將其存儲在變量m 中。
    trainingMat = zeros((m, 1024))
    for i in range(m):
        #從文件名解析分類數:接着,代碼創建一個m 行1024列的訓練矩陣,該矩陣的每行數據存儲一個圖像。
        # 我們可以從文件名中解析出分類數字
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)#類代碼存儲在hwLabels 向量中
        trainingMat[i,:] = img2vector('digits/trainingDigits/%s' % fileNameStr)
    testFileList = listdir('digits/testDigits')#獲取目錄內容;testDigits目錄中的文件內容存儲在列表中
    errorCount = 0.0
    mTest = len(testFileList)#得到目錄中有多少文件,並將其存儲在變量mTest 中。
    #testDigits目錄中的文件執行相似的操作,
    # 不同之處是我們並不將這個目錄下的文件載入矩陣中,
    # 而是使用classify0() 函數測試該目錄下的每個文件
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels,3)
        print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
        print('\nthe total number of errors is: %d' % errorCount)
        print('\nthe total error rate is: %f' % (errorCount / float(mTest)))

test = handwritingClassTest()

算法結果如下:

 參考文獻:

[1]peter harrington,機器學習實戰[M].

[2]李航,統計學習方法(第二版)[M].

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM