BP神經網絡的手寫數字識別


BP神經網絡的手寫數字識別

 

 

原創實踐:基於BP神經網絡的手寫數字識別

ANN 人工神經網絡算法在實踐中往往給人難以琢磨的印象,有句老話叫“出來混總是要還的”,大概是由於具有很強的非線性模擬和處理能力,因此作為代價上帝讓它“黑盒”化了。作為一種general purpose的學**算法,如果你實在不想去理會其他類型算法的理論基礎,那就請使用ANN吧。本文為筆者使用BP神經網絡進行手寫數字識別的整體思路和算法實現,由於近年來神經網絡在深度學**,尤其是無監督特征學**上的成功,理解神經網絡的實現機制也許可以讓“黑盒”變得不再神秘。

首先,作為一篇面向機器學**愛好者的文章,基本的理論介紹還是必要的。BP (Back Propagation)神經網絡,即誤差反傳誤差反向傳播算法的學**過程,由信息的正向傳播和誤差的反向傳播兩個過程組成。輸入層各神經元負責接收來自外界的輸入信息,並傳遞給中間層各神經元;中間層是內部信息處理層,負責信息變換,根據信息變化能力的需求,中間層可以設計為單隱層或者多隱層結構;最后一個隱層傳遞到輸出層各神經元的信息,經進一步處理后,完成一次學**的正向傳播處理過程,由輸出層向外界輸出信息處理結果。當實際輸出與期望輸出不符時,進入誤差的反向傳播階段。誤差通過輸出層,按誤差梯度下降的方式修正各層權值,向隱層、輸入層逐層反傳。周而復始的信息正向傳播和誤差反向傳播過程,是各層權值不斷調整的過程,也是神經網絡學**訓練的過程,此過程一直進行到網絡輸出的誤差減少到可以接受的程度,或者預先設定的學**次數為止。

網絡結構

輸入層有n個神經元,隱含層有p個神經元,輸出層有q個神經元。

變量定義

參數訓練

算法流程

圖像預處理與歸一化

好了,理論介紹完了,接下來看看如何識別手寫數字。

輸入樣本示例

輸入樣本為書寫數字的圖像,數據下載百度網盤鏈接:http://yun.baidu.com/s/1nt3UewD

本文采用逐像素特征提取的方法提取數字樣本的特征向量。將像素點RGB值之和大於255的像素點,特征值設為1,反之設為0。歸一化的圖像生成一個28x28的布爾矩陣,依次取每列元素,轉化為784x1的列矩陣,作為輸入圖像的特征向量。

模型訓練與模型識別

BP神經網絡模型參數設置

樣本描述:

受訓練樣本限制,本文僅對0,1,2,3,4這5個數字進行識別。每個數字對應訓練樣本與測試樣本數量如下表

模型評估

 

程序設計參考(Java)

public class AnnModel:前反饋神經網絡模型

主要成員變量

public HiddenNeuron[] hiddenLayer;

//隱含層神經元數組

public OutputNeuron[] outputLayer;

//輸出層神經元數組

publicintinputLayerSize,hiddenLayerSize,outputLayerSize;

//輸入層神經元個數、隱含層神經元個數、輸出層神經元個數

主要方法

public BPAnnModel(intinputLayerSize,inthiddenLayerSize,int outputLayerSize)

//根據輸入層、隱含層與輸出層神經元個數,初始化隱含層與輸出層各神經元,構造網絡模型

publicdouble[] calculateHiddenLayerOutput(double[] inputLayer)

//根據輸入層計算隱含層輸出

publicdouble[] calculateModelOutput(double[] inputLayer)

//根據輸入層計算輸出層輸出

public String getOutputClass(double[] inputLayers)

//根據輸入層預測分類結果

publicvoidsave(String modelSavePath)

//將模型的訓練結果保存到文本文件中

publicstatic BPAnnModel load(String modelSavePath)

//從保存訓練結果的文本文件中讀取模型參數,創建神經網絡模型實例

 

public class Neuron:神經元超類

主要成員變量

publicdoubletheta=0.2;

//theta 為神經元興奮度閾值

publicdouble[] weight;

//權值向量

publicdoublemu=0.1;

//mu為梯度下降的學**速率

publicintidx;

//idx為神經元在其所在層的索引位置

主要方法

public Neuron(intweightArrLength,intidx)

//根據權值向量長度初始化權值向量,構造神經元實例,this.idx=idx

publicdouble calculateActivation(double[] layerInput)

//根據神經元輸入計算凈活躍度

publicdouble calculateOutput(double[]layerInput)

//根據神經元輸入計算神經元輸出

publicdouble logisticSigmod(doublefixedActivation)

//激活函數

publicdouble logisticSigmod_1(doubleoutput)

//激活函數一階導數

public class HiddenNeuron:隱含層神經元,Neuron子類

publicvoidupdateWeight(double[] inputLayer, double[]yArr,

double[] tmpHiddenLayerOutput,

double[] tmpModelOutput)

//根據網絡模型輸入層、期望輸出、隱含層輸出與輸出層修正權值向量

public class OutputNeuron:輸出層神經元,Neuron子類

publicvoidupdateWeight(double[] inputLayer, double[]yArr)

//根據網絡模型輸入層和期望輸出修正權值向量

public class BPBP神經網絡訓練主程序

主要成員變量

publicdoubleepsilon=0.005;

//訓練停止的最小誤差

publicintcurTrainingTimes=0,maxTrainingTimes=50000;

//當前訓練次數與最大訓練次數

public BPAnnModel model;

//神經網絡模型

publicdouble[][] trainData;

//訓練數據集

主要方法

publicvoidsetModel(BPAnnModel model)

//加載權值已訓練好的神經網絡模型

publicvoidsetTrain(String trainPath,String varType,Stringdelimiter)

//輸入訓練數據文件與文件格式生成訓練數據集

publicvoidtrain()

//執行訓練過程

publicdouble calculateSingleError(double[] inputLayer,double[] yArr,boolean standardizeTrans)

//計算單訓練點誤差

publicdouble calculateGlobalError()

//計算全局誤差

注意事項

1.輸入樣本歸一化處理 由於模型輸入數據的單位不一樣,有些數據的范圍可能特別大,導致的結果是神經網絡收斂慢、訓練時間長、模型不穩定。因此在進行訓練之前通常需要對數據進行歸一化處理。為了統一各層神經元的輸入標准,采用S形激活函數時,激活函數的值域為 (0,1),通常采用max&min的方法將變量映射成 (0,1) 之間的數;采用雙極S形激活函數時,激活函數的值域為 (-1,1),通常采用2(max&min)-1的方法將變量映射成 (-1,1) 之間的數。

2.減少計算全局誤差的頻次 每次迭代過程中,權值修正的效率只和網絡的規模有關系,而計算全局誤差消耗的時間則和樣本量的大小直接成正比。在樣本量較大的情況下,頻繁的計算全局誤差會帶來及大的開銷。為了提高訓練效率,可以采取批量式的訓練方法,即在第六步完成后,直接跳回第二步,進行指定次迭代后,再進入到第七步進行全局誤差檢驗。

3.輸出層定義 當目標變量有m個分類時,輸出層的神經元個數通常有有兩種定義方法,m或以2為底的log m。建議采用m作為輸出層神經元個數,一方面可以很方便地從輸出層輸出信號映射到分類值;另一方面可以給出分類結論的“置信度”。采用以2為底的log m作為輸出層神經元個數時,通常采用輸出信號的二進制表示去映射指定分類。

4.學**效率(梯度下降的步長)參數設定

學**效率直接影響着網絡收斂的速度,以及網絡能否收斂。學**效率設置偏小可以保證網絡收斂,但是收斂較慢;反之則有可能使網絡訓練不收斂,影響識別效果。因此可以在誤差快速下降后放緩學**效率,增強模型穩定性。

5.過擬合問題 神經網絡計算不能一味地追求訓練誤差最小,這樣很容易出現“過擬合”現象,只要能夠實時檢測誤差率的變化就可以確定最佳的訓練次數,比如在本文中25000次左右的學**次數即可在測試數據上達到最優效果,25000次之后的學**,不僅徒增計算量,而且還有過擬合的風險。

 
 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM