用LSTM分類 MNIST


 

    LSTM是RNN的一種算法, 在序列分類中比較有用。常用於語音識別,文字處理(NLP)等領域。 

等同於VGG等CNN模型在在圖像識別領域的位置。  本篇文章是敘述LSTM 在MNIST 手寫圖中的使用。

用來給初步學習RNN的一個范例,便於學習和理解LSTM .

    先把工作流程圖貼一下

 

代碼片段

   數據准備

def makedata():
    img_rows, img_cols = 28, 28

    mnist = fetch_mldata("MNIST original")
    # rescale the data, use the traditional train/test split
    X_1D, y_int = mnist.data / 255., mnist.target
    y = np_utils.to_categorical(y_int, num_classes=10)

    X = X_1D.reshape(X_1D.shape[0], img_rows, img_cols )

    input_shape = (img_rows, img_cols, 1)
    x_train, x_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]

    return X, y
    pass

下載 MNIST數據, 進行歸一化  mnist.data / 255, 把數據[7000,784 ] 轉成[ 70000,28,28] 

 

構建模型:

def buildlstm():

    import numpy as np

    data_dim = 28
    timesteps = 28
    num_classes = 10

    # expected input data shape: (batch_size, timesteps, data_dim)
    model = Sequential()
    model.add(LSTM(32, return_sequences=True,   input_shape=(timesteps, data_dim+14)))   
    model.add(LSTM(32, return_sequences=True))  
    model.add(LSTM(32))  
    model.add(Dense(10, activation='softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])
    print model.summary()
    return  model
    pass

基礎參數: data_dim, timesteps, num_classes   分別為 28,28, 10
網絡層級 :    LSTM ----》LSTM ----》LSTM ----》Dense
注意點: input_shape=(timesteps, data_dim+14))   此處 應該為  data_dim , data_dim+14是我做第二個試驗使用。
網絡理解: RNN是用前一部分數據對當前數據的影響,並共同作用於最后結果。 用基礎的深度神經網絡(只有Dense層),是把MNIST一個圖形,
提取成784個像素數據,把784個數據扔給神經網絡,784個數據是同等的概念。 訓練出權重來確定最終的分類值。   

RNN 之於MNIST, 是把MNIST 分成 28x28 數據。可以理解為用一個激光掃描一個圖片,掃成28個(行)數據, 每行為28個像素。 站在時間序列
的角度,其實圖片沒有序列概念。但是我們可以這樣理解, 每一行於下一行是有位置關系的,不能進行順序變化。 比如一個手寫 “7”字, 如果把28行
的上下行順序打亂, 那么7 上面的一橫就可能在中間位置,也可能在下面的位置。  這樣,最終的結果就不應該是 7 .  
所以MNIST 的 28x28可以理解為 有時序關系的數據。 

訓練預測:

def runTrain(model, x_train, x_test, y_train, y_test):
    model.fit(x_train, y_train,  batch_size= nbatch_size, epochs= nEpoches)
    score = model.evaluate(x_test, y_test, batch_size=nbatch_size)
    print 'evaluate score:', score
    pass

這部分應該沒什么好說的

主程序:

def test():

    X,y = makedata2()
    x_train, x_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    model = buildlstm()
    runTrain(model, x_train, x_test, y_train, y_test )
    pass


運行結果

結構:
Layer (type)                 Output Shape              Param #
=================================================================
lstm_1 (LSTM)                (None, 28, 32)            7808
_________________________________________________________________
lstm_2 (LSTM)                (None, 28, 32)            8320
_________________________________________________________________
lstm_3 (LSTM)                (None, 32)                8320
_________________________________________________________________
dense_1 (Dense)              (None, 10)                330
=================================================================
Total params: 24,778
Trainable params: 24,778
Non-trainable params: 0
_________________________________________________________________


結果:
base    lstm for mnist
acc : 98.56%

結果2:
把數據最后增加 50%  的 0 , (dim X 0.5)
acc : 98.39%
結果基本上 與原數據一致

 

該實驗證明兩個結論:
1.  LSTM可用於圖形識別
2.  在數據中 每行28個基礎像素后面 + 14 個空白(0)的元素,不影分類識別。 


寫在最后:  本實驗的目的是為了理解RNN(LSTM),  只有理解了才能很好的使用。 本文章的目的是為記錄和分享。
再說下 RNN在其它領域的應用。  比如在語音識別領域,一個音譜,識別成一個單詞(詞語),可以理解成一個
豎向掃描的MNIST ,   一個股票的K線圖,也可以理解一個豎向掃描的MNIST。  還有其它領域,可以歸納遞推。 
入門之后, 如何在自己的領域,再深入(構建復雜模型,優化數據的處理),提高網絡模型的識別准確,那需要
見仁見智的。 

代碼文件鏈接:

源碼下載

 
有對 金融程序化 和 深度學習結合有興趣的可以加群 , 個人群: 杭州程序化交易群  375129936


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM