tensorflow 變量共享涉及到幾個常用的方法,tf.get_variable, tf.variable_scope, tf.reuse_variables等
為了記憶各個方法的功能,與其他方法做一個對比。
tf.variable 與 tf.get_variable
tensorflow中有兩種方法生成變量variable, 一種是tf.get_variable(), 另一種是tf.Variable()。
tf.Variable() 在定義name 相同的變量時,為了不重復變量名,會自動給變量賦一個區別於前一個變量的名字(末尾_1等)。因此使用tf.Variable() 定義的變量無法通過name屬性獲取tf.Variable對象。
tf.get_Variable()的name參數則是唯一的識別標准,因此只要通過tf.get_variables(name='xx')獲取得到的變量都是同一個變量。因此更方便於參數共享。而在重復使用時,一定要在代碼中強調scope.reuse_variables(),否則系統將會報錯。
因此,推薦不論在任何時候創建變量都使用tf.get_variable(),從而可以在任何地方對他進行共享。
tf.name_scope() 與 tf.variable_scope():
tf.name_scope()可以簡單的理解為為了更好的管理命名空間的方法。且只會影響tf.Variable()定義的變量的name。
tf.variable_scope()則可以影響到tf.get_variable()創建的對象的name。因此可以與tf.get_variable()一同使用,完成變量共享的目的,同時對命名空間進行管理。
tf.layers 參數復用
例如tf.layers.dense(), tf.layers.conv2D()等,參數復用只需要再tf.layers.dense(x, 4, name='h1', reuse=True),使得參數reuse為True,即可復用上一層的參數。
為了驗證,我們可以通過:
x = tf.ones((1, 3))
y1 = tf.layers.dense(x, 4, name='h1')
y2 = tf.layers.dense(x, 4, name='h1', reuse=True)
# y1 and y2 will evaluate to the same values
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(y1))
print(sess.run(y2)) # both prints will return the same values
觀察輸出變量是否相同的方式來判斷兩個dense層是否共享參數。
siamese LSTM networks(暹羅LSTM網絡結構)
以QA問題、答案對匹配為例:
為實現暹羅神經網絡,我們需要使用共享參數的LSTM網絡分別對問題和答案提取特征。
不同於tf.nn.dense,由於tf.nn.dynamic_rnn等是基於tf.nn.rnn_cell.LSTMCell() 構造得到的網絡結構,因此只需要讓dynamic_rnn(cell)中的cell輸入為同一個LSTMCell即可:
utterance_gru = tf.nn.rnn_cell.LSTMCell(self.rnn_units, initializer=tf.orthogonal_initializer(),
state_is_tuple=True,)
_, utterance_gru_embeddings = tf.nn.dynamic_rnn(utterance_gru, all_utterance_embeddings,
sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='utterance_rnn')
utterance_gru_embeddings = utterance_gru_embeddings[1]
_, response_gru_embeddings= tf.nn.dynamic_rnn(utterance_gru, response_embeddings,
sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='response_rnn')
response_gru_embeddings = response_gru_embeddings[1]
self.utt = utterance_gru_embeddings[0]
self.res = response_gru_embeddings[0]
# if all_utterance_embeddings == response_embeddings, self.utt == self.res