tensorflow 變量共享


tensorflow 變量共享涉及到幾個常用的方法,tf.get_variable, tf.variable_scope, tf.reuse_variables等
為了記憶各個方法的功能,與其他方法做一個對比。

tf.variable 與 tf.get_variable

tensorflow中有兩種方法生成變量variable, 一種是tf.get_variable(), 另一種是tf.Variable()。

tf.Variable() 在定義name 相同的變量時,為了不重復變量名,會自動給變量賦一個區別於前一個變量的名字(末尾_1等)。因此使用tf.Variable() 定義的變量無法通過name屬性獲取tf.Variable對象。

tf.get_Variable()的name參數則是唯一的識別標准,因此只要通過tf.get_variables(name='xx')獲取得到的變量都是同一個變量。因此更方便於參數共享。而在重復使用時,一定要在代碼中強調scope.reuse_variables(),否則系統將會報錯。

因此,推薦不論在任何時候創建變量都使用tf.get_variable(),從而可以在任何地方對他進行共享。

tf.name_scope() 與 tf.variable_scope():

tf.name_scope()可以簡單的理解為為了更好的管理命名空間的方法。且只會影響tf.Variable()定義的變量的name。
tf.variable_scope()則可以影響到tf.get_variable()創建的對象的name。因此可以與tf.get_variable()一同使用,完成變量共享的目的,同時對命名空間進行管理。

tf.layers 參數復用

例如tf.layers.dense(), tf.layers.conv2D()等,參數復用只需要再tf.layers.dense(x, 4, name='h1', reuse=True),使得參數reuse為True,即可復用上一層的參數。
為了驗證,我們可以通過:

x = tf.ones((1, 3))
y1 = tf.layers.dense(x, 4, name='h1')
y2 = tf.layers.dense(x, 4, name='h1', reuse=True)

# y1 and y2 will evaluate to the same values
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(y1))
print(sess.run(y2))  # both prints will return the same values

觀察輸出變量是否相同的方式來判斷兩個dense層是否共享參數。

siamese LSTM networks(暹羅LSTM網絡結構)

以QA問題、答案對匹配為例:
為實現暹羅神經網絡,我們需要使用共享參數的LSTM網絡分別對問題和答案提取特征。
不同於tf.nn.dense,由於tf.nn.dynamic_rnn等是基於tf.nn.rnn_cell.LSTMCell() 構造得到的網絡結構,因此只需要讓dynamic_rnn(cell)中的cell輸入為同一個LSTMCell即可:

utterance_gru = tf.nn.rnn_cell.LSTMCell(self.rnn_units, initializer=tf.orthogonal_initializer(), 
                state_is_tuple=True,)
_, utterance_gru_embeddings = tf.nn.dynamic_rnn(utterance_gru, all_utterance_embeddings, 
                sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='utterance_rnn')
utterance_gru_embeddings = utterance_gru_embeddings[1]  
 _, response_gru_embeddings= tf.nn.dynamic_rnn(utterance_gru, response_embeddings, 
               sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='response_rnn')
response_gru_embeddings = response_gru_embeddings[1]
self.utt = utterance_gru_embeddings[0]
self.res = response_gru_embeddings[0] 
# if all_utterance_embeddings == response_embeddings, self.utt == self.res


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM