tensorflow.nn.bidirectional_dynamic_rnn()函数的用法


转自:http://blog.csdn.net/wuzqchom/article/details/75453327 

使用tensorflow.nn.bidirectional_dynamic_rnn()这个函数,就可以很方便的实现双向LSTM,很简洁。 
首先来看一下,函数:

def bidirectional_dynamic_rnn(
cell_fw, # 前向RNN
cell_bw, # 后向RNN
inputs, # 输入
sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
initial_state_fw=None,  # 前向的初始化状态(可选)
initial_state_bw=None,  # 后向的初始化状态(可选)
dtype=None, # 初始化和输出的数据类型(可选)
parallel_iterations=None,
swap_memory=False, 
time_major=False,
# 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`. 
# 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`. 
scope=None
)
返回值:
一个(outputs, output_states)的元组
其中,
1. outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设
time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。
2. output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。
output_state_fw和output_state_bw的类型为LSTMStateTuple。
LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

而cell_fw和cell_bw的定义是完全一样的。如果这两个cell选LSTM cell整个结构就是双向LSTM了。

# lstm模型 正方向传播的RNN
lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)
# 反方向传播的RNN
lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)

但是看来看去,输入两个cell都是相同的啊? 
其实在bidirectional_dynamic_rnn函数的内部,会把反向传播的cell使用array_ops.reverse_sequence的函数将输入的序列逆序排列,使其可以达到反向传播的效果。 
在实现的时候,我们是需要传入两个cell作为参数就可以了:

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, embedded_chars,  dtype=tf.float32)

embedded_chars为输入的tensor,[batch_szie, max_time, depth]。batch_size为模型当中batch的大小,应用在文本中时,max_time可以为句子的长度(一般以最长的句子为准,短句需要做padding),depth为输入句子词向量的维度。

当然你也可以使用循环的方式,时长为句子的长度,每一次都以上一时刻(假设一词为句子的基本单位的话,即上一个词)的隐藏状态和当前时刻的tensor为输入,但是这样写的时候相对会比较麻烦,若使用bidirectional_dynamic_rnn()则会清爽很多。 

本篇仅仅是在应用接口层面介绍了bidirectional_dynamic_rnn,内部实现并没有做过多的探讨,dynamic_rnn()函数也有一些工程上的优化,比如加入buckets机制。 
具体解释见知乎问题:tensorflow中的seq2seq例子为什么需要bucket?贾杨清的回答。 
另外关于dynamic_rnn和普通的rnn区别可见另外一个 
知乎问题: tensor flow dynamic_rnn 与rnn有啥区别?


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM