LSTM是RNN的一種算法, 在序列分類中比較有用。常用於語音識別,文字處理(NLP)等領域。
等同於VGG等CNN模型在在圖像識別領域的位置。 本篇文章是敘述LSTM 在MNIST 手寫圖中的使用。
用來給初步學習RNN的一個范例,便於學習和理解LSTM .
先把工作流程圖貼一下:
代碼片段 :
數據准備
def makedata(): img_rows, img_cols = 28, 28 mnist = fetch_mldata("MNIST original") # rescale the data, use the traditional train/test split X_1D, y_int = mnist.data / 255., mnist.target y = np_utils.to_categorical(y_int, num_classes=10) X = X_1D.reshape(X_1D.shape[0], img_rows, img_cols ) input_shape = (img_rows, img_cols, 1) x_train, x_test = X[:60000], X[60000:] y_train, y_test = y[:60000], y[60000:] return X, y pass
下載 MNIST數據, 進行歸一化 mnist.data / 255, 把數據[7000,784 ] 轉成[ 70000,28,28]
構建模型:
def buildlstm(): import numpy as np data_dim = 28 timesteps = 28 num_classes = 10 # expected input data shape: (batch_size, timesteps, data_dim) model = Sequential() model.add(LSTM(32, return_sequences=True, input_shape=(timesteps, data_dim+14))) model.add(LSTM(32, return_sequences=True)) model.add(LSTM(32)) model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) print model.summary() return model pass
基礎參數: data_dim, timesteps, num_classes 分別為 28,28, 10
網絡層級 : LSTM ----》LSTM ----》LSTM ----》Dense
注意點: input_shape=(timesteps, data_dim+14)) 此處 應該為 data_dim , data_dim+14是我做第二個試驗使用。
網絡理解: RNN是用前一部分數據對當前數據的影響,並共同作用於最后結果。 用基礎的深度神經網絡(只有Dense層),是把MNIST一個圖形,
提取成784個像素數據,把784個數據扔給神經網絡,784個數據是同等的概念。 訓練出權重來確定最終的分類值。
RNN 之於MNIST, 是把MNIST 分成 28x28 數據。可以理解為用一個激光掃描一個圖片,掃成28個(行)數據, 每行為28個像素。 站在時間序列
的角度,其實圖片沒有序列概念。但是我們可以這樣理解, 每一行於下一行是有位置關系的,不能進行順序變化。 比如一個手寫 “7”字, 如果把28行
的上下行順序打亂, 那么7 上面的一橫就可能在中間位置,也可能在下面的位置。 這樣,最終的結果就不應該是 7 .
所以MNIST 的 28x28可以理解為 有時序關系的數據。
訓練預測:
def runTrain(model, x_train, x_test, y_train, y_test): model.fit(x_train, y_train, batch_size= nbatch_size, epochs= nEpoches) score = model.evaluate(x_test, y_test, batch_size=nbatch_size) print 'evaluate score:', score pass
這部分應該沒什么好說的
主程序:
def test(): X,y = makedata2() x_train, x_test = X[:60000], X[60000:] y_train, y_test = y[:60000], y[60000:] model = buildlstm() runTrain(model, x_train, x_test, y_train, y_test ) pass
運行結果:
結構: Layer (type) Output Shape Param # ================================================================= lstm_1 (LSTM) (None, 28, 32) 7808 _________________________________________________________________ lstm_2 (LSTM) (None, 28, 32) 8320 _________________________________________________________________ lstm_3 (LSTM) (None, 32) 8320 _________________________________________________________________ dense_1 (Dense) (None, 10) 330 ================================================================= Total params: 24,778 Trainable params: 24,778 Non-trainable params: 0 _________________________________________________________________ 結果: base lstm for mnist acc : 98.56% 結果2: 把數據最后增加 50% 的 0 , (dim X 0.5) acc : 98.39% 結果基本上 與原數據一致
該實驗證明兩個結論:
1. LSTM可用於圖形識別
2. 在數據中 每行28個基礎像素后面 + 14 個空白(0)的元素,不影分類識別。
寫在最后: 本實驗的目的是為了理解RNN(LSTM), 只有理解了才能很好的使用。 本文章的目的是為記錄和分享。
再說下 RNN在其它領域的應用。 比如在語音識別領域,一個音譜,識別成一個單詞(詞語),可以理解成一個
豎向掃描的MNIST , 一個股票的K線圖,也可以理解一個豎向掃描的MNIST。 還有其它領域,可以歸納遞推。
入門之后, 如何在自己的領域,再深入(構建復雜模型,優化數據的處理),提高網絡模型的識別准確,那需要
見仁見智的。
代碼文件鏈接:
有對 金融程序化 和 深度學習結合有興趣的可以加群 , 個人群: 杭州程序化交易群 375129936