超詳細的RNN代碼實現(tensorflow)


一、學習單步的RNN:RNNCell

如果要學習TensorFlow中的RNN,第一站應該就是去了解“RNNCell”,它是TensorFlow中實現RNN的基本單元,每個RNNCell都有一個call方法,使用方式是:(output, next_state) = call(input, state)。
也就是說,每調用一次RNNCell的call方法,就相當於在時間上“推進了一步”,這就是RNNCell的基本功能。

在代碼實現上,RNNCell只是一個抽象類,我們用的時候都是用的它的兩個子類BasicRNNCell和BasicLSTMCell。顧名思義,前者是RNN的基礎類,后者是LSTM的基礎類。

找到源碼中BasicRNNCell的調用函數實現:

def調用(self,inputs,state): “”“最基本的RNN:output = new_state = act(W * input + U * state + B)。”“” output = self._activation(_linear([inputs,state] ,self._num_units,True)) return 輸出,輸出

"return輸出,輸出”說明在BasicRNNCell中,輸出其實和隱狀態的值是一樣的。因此還需要額外對輸出定義新的變換才能得到真正的輸出y。由於輸出和隱狀態是一回事,所以在BasicRNNCell中,state_size永遠等於output_size

除了call方法外,對於RNNCell,還有兩個類屬性比較重要:

state_size
output_size
前者是隱層的大小,后者是輸出的大小。比如我們通常是將一個batch送入模型計算,設輸入數據的形狀為(batch_size, input_size),那么計算時得到的隱層狀態就是(batch_size, state_size),輸出就是(batch_size, output_size)。

對於單層RNN:

import tensorflow as tf import numpy as np cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128) # state_size = 128 print(cell.state_size) # 128 inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size,100是input_size shape = (batch_size, input_size) h0 = cell.zero_state(32, np.float32) # 通過zero_state得到一個全0的初始狀態,形狀為(batch_size, state_size) output, h1 = cell.call(inputs, h0) #調用call函數 print(h1.shape) # (32, 128)

對於多層RNN:

import tensorflow as tf import numpy as np num_layers = 2  #層數
hidden_size = [128,256] #每一層的隱節點個數(可以不一樣)
rnn_cells = [] #包含所有層的列表

for i in range(num_layers): # 構建一個基本rnn單元(一層)
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(lstm_size[i]) # 可以添加dropout
    drop_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell , output_keep_prob=keep_prob) rnn_cells.append(drop_cell) # 堆疊多個LSTM單元
    cell = tf.nn.rnn_cell.MultiRNNCell(rnn_cells) initial_state = cell.zero_state(batch_size, tf.float32) return cell, initial_state ''' 注:對於老版本的tensorflow,堆疊多層RNN(或LSTM): cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell for _ in range(num_layers)]) '''

 

對於BasicLSTMCell,情況有些許不同,因為LSTM可以看做有兩個隱狀態h和c,對應的隱層就是一個Tuple,每個都是(batch_size, state_size)的形狀:

import tensorflow as tf import numpy as np lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128) inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size h0 = lstm_cell.zero_state(32, np.float32) # 通過zero_state得到一個全0的初始狀態 output, h1 = lstm_cell.call(inputs, h0) #h1包含兩個隱狀態 print(h1.h) # shape=(32, 128) print(h1.c) # shape=(32, 128)

二、學習如何一次執行多步:tf.nn.dynamic_rnn

基礎的RNNCell有一個很明顯的問題:對於單個的RNNCell,我們使用它的call函數進行運算時,只是在序列時間上前進了一步。比如使用x1、h0得到h1,通過x2、h1得到h2等。這樣的h話,如果我們的序列長度為10,就要調用10次call函數,比較麻煩。對此,TensorFlow提供了一個tf.nn.dynamic_rnn函數,使用該函數就相當於調用了n次call函數。即通過{h0,x1, x2, …., xn}直接得{h1,h2…,hn}。

具體來說,設我們輸入數據的格式為(batch_size, time_steps, input_size),其中time_steps表示序列本身的長度,如在Char RNN中,長度為10的句子對應的time_steps就等於10。最后的input_size就表示輸入數據單個序列單個時間維度上固有的長度。另外我們已經定義好了一個RNNCell,調用該RNNCell的call函數time_steps次,對應的代碼就是:

# inputs: shape = (batch_size, time_steps, input_size) # cell: RNNCell # initial_state: shape = (batch_size, cell.state_size)。初始狀態。一般可以取零矩陣

# inputs: shape = (batch_size, time_steps, input_size)
# cell: RNNCell
# initial_state: shape = (batch_size, cell.state_size)。初始狀態。一般可以取零矩陣

 
         

import tensorflow as tf

 
         

tf.reset_default_graph()
batch_size = 32 # batch大小
input_size = 100 # 輸入向量xt維度
state_size = 128 # 隱藏狀態ht維度
time_steps = 10 # 序列長度

 
         

inputs = tf.random_normal(shape=[batch_size, time_steps, input_size], dtype=tf.float32)
print("inputs.shape:",inputs.shape) #(32,10,100)

 
         

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units = state_size)
print(lstm_cell.state_size) #(c=128,h=128)

 
         

initial_state = lstm_cell.zero_state(batch_size, dtype = tf.float32)
print(initial_state.h, initial_state.c)  #(32,128),(32,128)

 
         

outputs, state = tf.nn.dynamic_rnn(lstm_cell, inputs, initial_state = initial_state)

 
         

print(outputs)  #(32,10,128)
print(state)    #(32,128) state是最終(最后一個time_step)的狀態
print(state.h, state.c) #(32,128),(32,128)

 

 

轉自知乎:https://zhuanlan.zhihu.com/p/28196873


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM