tensorflow 变量共享涉及到几个常用的方法,tf.get_variable, tf.variable_scope, tf.reuse_variables等
为了记忆各个方法的功能,与其他方法做一个对比。
tf.variable 与 tf.get_variable
tensorflow中有两种方法生成变量variable, 一种是tf.get_variable(), 另一种是tf.Variable()。
tf.Variable() 在定义name 相同的变量时,为了不重复变量名,会自动给变量赋一个区别于前一个变量的名字(末尾_1等)。因此使用tf.Variable() 定义的变量无法通过name属性获取tf.Variable对象。
tf.get_Variable()的name参数则是唯一的识别标准,因此只要通过tf.get_variables(name='xx')获取得到的变量都是同一个变量。因此更方便于参数共享。而在重复使用时,一定要在代码中强调scope.reuse_variables(),否则系统将会报错。
因此,推荐不论在任何时候创建变量都使用tf.get_variable(),从而可以在任何地方对他进行共享。
tf.name_scope() 与 tf.variable_scope():
tf.name_scope()可以简单的理解为为了更好的管理命名空间的方法。且只会影响tf.Variable()定义的变量的name。
tf.variable_scope()则可以影响到tf.get_variable()创建的对象的name。因此可以与tf.get_variable()一同使用,完成变量共享的目的,同时对命名空间进行管理。
tf.layers 参数复用
例如tf.layers.dense(), tf.layers.conv2D()等,参数复用只需要再tf.layers.dense(x, 4, name='h1', reuse=True),使得参数reuse为True,即可复用上一层的参数。
为了验证,我们可以通过:
x = tf.ones((1, 3))
y1 = tf.layers.dense(x, 4, name='h1')
y2 = tf.layers.dense(x, 4, name='h1', reuse=True)
# y1 and y2 will evaluate to the same values
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(y1))
print(sess.run(y2)) # both prints will return the same values
观察输出变量是否相同的方式来判断两个dense层是否共享参数。
siamese LSTM networks(暹罗LSTM网络结构)
以QA问题、答案对匹配为例:
为实现暹罗神经网络,我们需要使用共享参数的LSTM网络分别对问题和答案提取特征。
不同于tf.nn.dense,由于tf.nn.dynamic_rnn等是基于tf.nn.rnn_cell.LSTMCell() 构造得到的网络结构,因此只需要让dynamic_rnn(cell)中的cell输入为同一个LSTMCell即可:
utterance_gru = tf.nn.rnn_cell.LSTMCell(self.rnn_units, initializer=tf.orthogonal_initializer(),
state_is_tuple=True,)
_, utterance_gru_embeddings = tf.nn.dynamic_rnn(utterance_gru, all_utterance_embeddings,
sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='utterance_rnn')
utterance_gru_embeddings = utterance_gru_embeddings[1]
_, response_gru_embeddings= tf.nn.dynamic_rnn(utterance_gru, response_embeddings,
sequence_length=self.utterance_len_ph,dtype=tf.float32, scope='response_rnn')
response_gru_embeddings = response_gru_embeddings[1]
self.utt = utterance_gru_embeddings[0]
self.res = response_gru_embeddings[0]
# if all_utterance_embeddings == response_embeddings, self.utt == self.res