包括卷積神經網絡(CNN)在內的各種前饋神經網絡模型, 其一次前饋過程的輸出只與當前輸入有關與歷史輸入無關.
遞歸神經網絡(Recurrent Neural Network, RNN)充分挖掘了序列數據中的信息, 在時間序列和自然語言處理方面有着重要的應用.
遞歸神經網絡可以展開為普通的前饋神經網絡:

長短期記憶模型(Long-Short Term Memory)是RNN的常用實現. 與一般神經網絡的神經元相比, LSTM神經元多了一個遺忘門.

LSTM神經元的輸出除了與當前輸入有關外, 還與自身記憶有關. RNN的訓練算法也是基於傳統BP算法增加了時間考量, 稱為BPTT(Back-propagation Through Time)算法.
使用tensorflow內置rnn
tensorflow內置了遞歸神經網絡的實現:
from tensorflow.python.ops import rnn, rnn_cell
tensorflow目前正在快速迭代中, 上述路徑可能會發生變化.在0.6.0版本中上述路徑是有效的.
官方教程中已經加入了循環神經網絡的部分, API可能不會發生太大變化.
Tensorflow有多種rnn神經元可供選擇:
-
rnn_cell.BasicLSTMCell -
rnn_cell.LSTMCell -
rnn_cell.GRUCell
這里我們選用最簡單的BasicLSTMCell, 需要設置神經元個數和forget_bias參數:
self.lstm_cell = rnn_cell.BasicLSTMCell(hidden_n, forget_bias=1.0)
可以直接調用cell對象獲得輸出和狀態:
output, state = cell(inputs, state)
使用dropout避免過擬合問題:
from tensorflow.python.ops.rnn_cell import Dropoutwrapper
cells = DropoutWrapper(lstm_cell, input_keep_prob=0.5, output_keep_prob=0.5)
使用MultiRNNCell來創建多層神經網絡:
from tensorflow.python.ops.rnn_cell import MultiRNNCell
cells = MultiRNNCell([lstm_cell_1, lstm_cell_2])
不過rnn.rnn可以替我們完成神經網絡的構建工作:
outputs, states = rnn.rnn(self.lstm_cell, self.input_layer, dtype=tf.float32)
再加一個輸出層進行輸出:
self.prediction = tf.matmul(outputs[-1], self.weights) + self.biases
定義損失函數:
self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.prediction, self.label_layer))
使用Adam優化器進行訓練:
self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)
因為神經網絡需要處理序列數據, 所以輸入層略復雜:
self.input_layer = [tf.placeholder("float", [step_n, input_n]) for i in range(batch_size)]
tensorflow要求RNNCell的輸入為一個列表, 列表中的每一項作為一個批次進行訓練.
列表中的每一個元素代表一個序列, 每一行為序列中的一項. 這樣每一項為一個形狀為(序列長, 輸入維數)的矩陣.
標簽還是和原來一樣為形如(序列長, 輸出維度)的矩陣:
self.label_layer = tf.placeholder("float", [step_n, output_n])
執行訓練:
self.session.run(initer)
for i in range(limit):
self.session.run(self.trainer, feed_dict={self.input_layer[0]: train_x[0], self.label_layer: train_y})
因為input_layer為列表, 而列表不能作為字典的鍵.所以我們只能采用{self.input_layer[0]: train_x[0]}這樣的方式輸入數據.
可以看到lable_layer也是二維的, 並沒有輸入多個批次的數據. 考慮到這兩點, 目前這個實現並不具備多批次處理的能力.
序列的長度通常是不同的, 而目前的實現采用的是定長輸入. 這是需要解決的另一個難題.
完整源代碼可以在demo.py中查看.
