Python機器學習--手寫體識別(KNN+MLP)


  • MLP實現

 

 

  • 調整參數比較性能結果
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 21:14:38 2017

@author: Administrator 
"""

import numpy as np     #導入numpy工具包
from os import listdir #使用listdir模塊,用於訪問本地文件
from sklearn.neural_network import MLPClassifier  ## 版本選擇sklearn-v0.18;sklearn更新anaconda方法:conda update scikit-learn
 
#定義img2vector函數,將加載的32*32的圖片矩陣展開成一列向量
def img2vector(fileName):    
    retMat = np.zeros([1024],int) #定義返回的矩陣,大小為1*1024
    fr = open(fileName)           #打開包含32*32大小的數字文件 
    lines = fr.readlines()        #讀取文件的所有行
    for i in range(32):           #遍歷文件所有行
        for j in range(32):       #並將01數字存放在retMat中     
            retMat[i*32+j] = lines[i][j]    
    return retMat
 
 #定義加載訓練數據的函數readDataSet,並將樣本標簽轉化為one-hot向量
def readDataSet(path):    
    fileList = listdir(path)    #獲取文件夾下的所有文件 
    numFiles = len(fileList)    #統計需要讀取的文件的數目
    dataSet = np.zeros([numFiles,1024],int) #用於存放所有的數字文件
    hwLabels = np.zeros([numFiles,10])      #用於存放對應的one-hot標簽
    for i in range(numFiles):   #遍歷所有的文件
        filePath = fileList[i]  #獲取文件名稱/路徑      
        digit = int(filePath.split('_')[0])  #通過文件名獲取標簽      
        hwLabels[i][digit] = 1.0        #將對應的one-hot標簽置1
        dataSet[i] = img2vector(path +'/'+filePath) #讀取文件內容   
    return dataSet,hwLabels
 
#read dataSet
fpath='F:\RANJIEWEN\MachineLearning\Python機器學習實戰_mooc\data\手寫數字\digits\\'
train_dataSet, train_hwLabels = readDataSet(fpath+'trainingDigits')
 
# 調整參數,隱藏層數量,學習率,最大迭代次數比較性能結果
clf = MLPClassifier(hidden_layer_sizes=(100,),
                    activation='logistic', solver='adam',
                    learning_rate_init = 0.00001, max_iter=2000)
print(clf)
clf.fit(train_dataSet,train_hwLabels)
 
#read  testing dataSet
dataSet,hwLabels = readDataSet(fpath+'testDigits')
res = clf.predict(dataSet)   #對測試集進行預測
error_num = 0                #統計預測錯誤的數目
num = len(dataSet)           #測試集的數目
for i in range(num):         #遍歷預測結果
    #比較長度為10的數組,返回包含01的數組,0為不同,1為相同
    #若預測結果與真實結果相同,則10個數字全為1,否則不全為1
    if np.sum(res[i] == hwLabels[i]) < 10: 
        error_num += 1                     
print("Total num:",num," Wrong num:", \
      error_num,"  WrongRate:",error_num / float(num))
  • kNN比較
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 31 10:11:15 2017

@author: Administrator   knn-neighbors
"""

import numpy as np     #導入numpy工具包
from os import listdir #使用listdir模塊,用於訪問本地文件
from sklearn import neighbors
 
#定義img2vector函數,將加載的32*32的圖片矩陣展開成一列向量
def img2vector(fileName):    
    retMat = np.zeros([1024],int) #定義返回的矩陣,大小為1*1024
    fr = open(fileName)           #打開包含32*32大小的數字文件 
    lines = fr.readlines()        #讀取文件的所有行
    for i in range(32):           #遍歷文件所有行
        for j in range(32):       #並將01數字存放在retMat中     
            retMat[i*32+j] = lines[i][j]    
    return retMat

    
#定義加載訓練數據的函數readDataSet,並將樣本標簽轉化為one-hot向量
def readDataSet(path):    
    fileList = listdir(path)    #獲取文件夾下的所有文件 
    numFiles = len(fileList)    #統計需要讀取的文件的數目
    dataSet = np.zeros([numFiles,1024],int)    #用於存放所有的數字文件
    hwLabels = np.zeros([numFiles])#用於存放對應的標簽(與神經網絡的不同)
    for i in range(numFiles):      #遍歷所有的文件
        filePath = fileList[i]     #獲取文件名稱/路徑   
        digit = int(filePath.split('_')[0])   #通過文件名獲取標簽     
        hwLabels[i] = digit        #直接存放數字,並非one-hot向量
        dataSet[i] = img2vector(path +'/'+filePath)    #讀取文件內容 
    return dataSet,hwLabels
 

#read dataSet
fpath='F:\RANJIEWEN\MachineLearning\Python機器學習實戰_mooc\data\手寫數字\digits\\'

train_dataSet, train_hwLabels = readDataSet(fpath+'trainingDigits')
knn = neighbors.KNeighborsClassifier(algorithm='kd_tree', n_neighbors=3)
knn.fit(train_dataSet, train_hwLabels)
 
#read  testing dataSet
dataSet,hwLabels = readDataSet(fpath+'testDigits')
 
res = knn.predict(dataSet)  #對測試集進行預測
error_num = np.sum(res != hwLabels) #統計分類錯誤的數目
num = len(dataSet)          #測試集的數目
print("Total num:",num," Wrong num:", \
      error_num,"  WrongRate:",error_num / float(num))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM