tensorflow 变量共享


tensorflow 变量共享涉及到几个常用的方法,tf.get_variable, tf.variable_scope, tf.reuse_variables等
为了记忆各个方法的功能,与其他方法做一个对比。

tf.variable 与 tf.get_variable

tensorflow中有两种方法生成变量variable, 一种是tf.get_variable(), 另一种是tf.Variable()。

tf.Variable() 在定义name 相同的变量时,为了不重复变量名,会自动给变量赋一个区别于前一个变量的名字(末尾_1等)。因此使用tf.Variable() 定义的变量无法通过name属性获取tf.Variable对象。

tf.get_Variable()的name参数则是唯一的识别标准,因此只要通过tf.get_variables(name='xx')获取得到的变量都是同一个变量。因此更方便于参数共享。而在重复使用时,一定要在代码中强调scope.reuse_variables(),否则系统将会报错。

因此,推荐不论在任何时候创建变量都使用tf.get_variable(),从而可以在任何地方对他进行共享。

tf.name_scope() 与 tf.variable_scope():

tf.name_scope()可以简单的理解为为了更好的管理命名空间的方法。且只会影响tf.Variable()定义的变量的name。
tf.variable_scope()则可以影响到tf.get_variable()创建的对象的name。因此可以与tf.get_variable()一同使用,完成变量共享的目的,同时对命名空间进行管理。

tf.layers 参数复用

例如tf.layers.dense(), tf.layers.conv2D()等,参数复用只需要再tf.layers.dense(x, 4, name='h1', reuse=True),使得参数reuse为True,即可复用上一层的参数。
为了验证,我们可以通过:

x = tf.ones((1, 3))
y1 = tf.layers.dense(x, 4, name='h1')
y2 = tf.layers.dense(x, 4, name='h1', reuse=True)

# y1 and y2 will evaluate to the same values
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(y1))
print(sess.run(y2))  # both prints will return the same values

观察输出变量是否相同的方式来判断两个dense层是否共享参数。

siamese LSTM networks(暹罗LSTM网络结构)

以QA问题、答案对匹配为例:
为实现暹罗神经网络,我们需要使用共享参数的LSTM网络分别对问题和答案提取特征。
不同于tf.nn.dense,由于tf.nn.dynamic_rnn等是基于tf.nn.rnn_cell.LSTMCell() 构造得到的网络结构,因此只需要让dynamic_rnn(cell)中的cell输入为同一个LSTMCell即可:

utterance_gru = tf.nn.rnn_cell.LSTMCell(self.rnn_units, initializer=tf.orthogonal_initializer(), 
                state_is_tuple=True,)
_, utterance_gru_embeddings = tf.nn.dynamic_rnn(utterance_gru, all_utterance_embeddings, 
                sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='utterance_rnn')
utterance_gru_embeddings = utterance_gru_embeddings[1]  
 _, response_gru_embeddings= tf.nn.dynamic_rnn(utterance_gru, response_embeddings, 
               sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='response_rnn')
response_gru_embeddings = response_gru_embeddings[1]
self.utt = utterance_gru_embeddings[0]
self.res = response_gru_embeddings[0] 
# if all_utterance_embeddings == response_embeddings, self.utt == self.res


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM