1. tf.nn.dynamic_rnn: usage and outputs
Official docs: https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn
Usage:
Args:
- cell: An instance of RNNCell. // the cell you define: BasicLSTMCell, BasicRNNCell, GRUCell, etc.
- inputs: If time_major == False (default), this must be a Tensor of shape [batch_size, max_time, ...], or a nested tuple of such elements. If time_major == True, this must be a Tensor of shape [max_time, batch_size, ...], or a nested tuple of such elements.
  # With time_major=True the input shape is [max_time, batch_size, input_size]; otherwise it is [batch_size, max_time, input_size];
- time_major: The shape format of the inputs and outputs Tensors. If True, these Tensors must be shaped [max_time, batch_size, depth]. If False, these Tensors must be shaped [batch_size, max_time, depth]. Using time_major = True is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form.
Returns:
- outputs: The RNN output Tensor.
If time_major == False (default), this will be a Tensor shaped: [batch_size, max_time, cell.output_size].
If time_major == True, this will be a Tensor shaped: [max_time, batch_size, cell.output_size].
# With time_major=True the output shape is [max_time, batch_size, hidden_size]; otherwise it is [batch_size, max_time, hidden_size];
- state: The final state.
If cell.state_size is an int, this will be shaped [batch_size, cell.state_size].
If it is a TensorShape, this will be shaped [batch_size] + cell.state_size. If it is a (possibly nested) tuple of ints or TensorShapes, this will be a tuple having the corresponding shapes.
If cells are LSTMCells, state will be a tuple containing an LSTM StateTuple for each cell.
Return values, in plain terms:
The LSTM or GRU produces two outputs, output and state;
output contains the hidden-layer outputs H for all time steps;
If the cell is an LSTM:
state contains the last time step's H and C for a single-layer RNN; for a multi-layer RNN it contains H and C for every layer;
If the cell is a GRU:
state contains the last time step's H for a single-layer RNN; for a multi-layer RNN it contains H for every layer;
# Note: calling sess.run twice executes the graph twice, so the random input is re-sampled and the two results differ;
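A minimal sketch contrasting the two state structures (the cell sizes and variable scopes here are arbitrary choices for illustration):

import tensorflow as tf

batch_size = 4
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(10)
gru_cell = tf.nn.rnn_cell.GRUCell(10)

# Separate variable scopes so the two RNNs do not clash over weights
with tf.variable_scope("lstm"):
    _, lstm_state = tf.nn.dynamic_rnn(lstm_cell, inputs,
                                      dtype=tf.float32, time_major=True)
with tf.variable_scope("gru"):
    _, gru_state = tf.nn.dynamic_rnn(gru_cell, inputs,
                                     dtype=tf.float32, time_major=True)

# LSTM: state is an LSTMStateTuple(c, h), each of shape [batch_size, 10]
print(lstm_state)
# GRU: state is a single tensor H of shape [batch_size, 10]
print(gru_state)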
Example code:

import tensorflow as tf

batch_size = 4
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state, time_major=True)
# With time_major=True the first dimension of the input holds the RNN steps;
# recommended, since it runs a little faster.
# With time_major=False the second dimension of the input holds the steps.
# With time_major=True the output shape is [max_time, batch_size, depth];
# otherwise it is [batch_size, max_time, depth] -- the same layout as the input.
# final_state is the final state of the whole LSTM, containing c and h;
# both c and h have shape [batch_size, n_hidden].

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # print(sess.run(output))
    # print(sess.run(final_state))
    print(sess.run([output, final_state]))
Comparing the outputs:
As the output shows, output contains the hidden layer's values at every time step; if layers are stacked, each time step of this output becomes the input to the corresponding time step of the next layer;
For an LSTM, state contains C and H, both being the current layer's output at the last time step; H equals the last time step of output, as the check below confirms;
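A quick way to verify that equality is to fetch both tensors in one sess.run call (a single graph execution, so both see the same random input); a minimal sketch, reusing the session from the example above:

import numpy as np

# One sess.run for both fetches, so the random input is sampled only once
out_val, state_val = sess.run([output, final_state])
# With time_major=True, out_val[-1] is the last time step: [batch_size, hidden_size]
assert np.allclose(out_val[-1], state_val.h)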
The output is as follows:
[array([[[ 0.11201711, 0.05266698, 0.12750182, 0.03627545, 0.02706259, -0.11562401, 0.08477378, 0.06157489, 0.07341921, 0.15011263], [-0.09552816, -0.17314027, -0.0895469 , -0.26399866, -0.36308575, 0.10537394, -0.09443868, -0.31130335, 0.0132737 , -0.12810872], [-0.00719012, 0.04438379, -0.03804718, -0.06637346, -0.02082551, 0.132549 , -0.05982352, 0.11778613, -0.09206182, 0.02547247], [ 0.14723007, 0.05410767, 0.06571447, 0.06775881, -0.03286515, 0.31600857, 0.03567648, 0.10357846, -0.0679171 , -0.00785992]], [[ 0.06683166, -0.05883167, 0.10910213, 0.05030679, 0.17738451, 0.00631482, -0.00457612, -0.03694798, 0.17743434, 0.06658468], [-0.03385706, -0.20001511, -0.05247132, -0.14611273, -0.17433529, 0.14970839, -0.07725038, -0.32652032, 0.09670977, -0.17828827], [ 0.03988864, -0.03118243, -0.09246919, 0.1831698 , -0.01006366, 0.01672944, 0.01009638, 0.10943947, -0.00420897, -0.0054652 ], [ 0.16778645, 0.08699884, 0.12571299, 0.12276714, 0.04879797, 0.10517071, 0.10341848, 0.15480027, -0.04619027, 0.11167715]], [[ 0.14293307, -0.10649989, 0.09144076, -0.03020415, 0.18182378, 0.22111537, -0.02275194, -0.14586878, 0.19310513, -0.02283864], [-0.0553881 , -0.16710383, -0.09584018, -0.06020959, -0.11862611, 0.05812657, -0.05461238, -0.21729217, 0.08961426, -0.1420837 ], [ 0.03053934, 0.02213254, -0.11577073, 0.08933022, -0.08349261, 0.044699 , 0.01332499, 0.14753158, -0.12446564, 0.00095996], [ 0.21244884, 0.11677884, 0.15352076, 0.04703464, 0.07084017, 0.04610508, 0.09713535, 0.12495688, 0.00218641, 0.17711937]]], dtype=float32),
LSTMStateTuple(
c=array([[ 0.264239 , -0.16139928, 0.25842854, -0.05938458, 0.38918033, 0.37621742, -0.06394874, -0.263255 , 0.32704324, -0.04286532], [-0.11041687, -0.3316248 , -0.21551779, -0.12425021, -0.2452825 , 0.12507899, -0.11451716, -0.40844095, 0.20570038, -0.28551656], [ 0.0634905 , 0.05425977, -0.19805768, 0.15730162, -0.14432296, 0.09046975, 0.02406704, 0.34546444, -0.22364679, 0.00243504], [ 0.40725306, 0.25660557, 0.3873769 , 0.11941462, 0.16212168, 0.10613891, 0.1803763 , 0.26139545, 0.00540481, 0.31761324]], dtype=float32),
h=array([[ 0.14293307, -0.10649989, 0.09144076, -0.03020415, 0.18182378, 0.22111537, -0.02275194, -0.14586878, 0.19310513, -0.02283864], [-0.0553881 , -0.16710383, -0.09584018, -0.06020959, -0.11862611, 0.05812657, -0.05461238, -0.21729217, 0.08961426, -0.1420837 ], [ 0.03053934, 0.02213254, -0.11577073, 0.08933022, -0.08349261, 0.044699 , 0.01332499, 0.14753158, -0.12446564, 0.00095996], [ 0.21244884, 0.11677884, 0.15352076, 0.04703464, 0.07084017, 0.04610508, 0.09713535, 0.12495688, 0.00218641, 0.17711937]], dtype=float32))]
2. tf.nn.bidirectional_dynamic_rnn: outputs
Official docs: https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn
The full example code is at the bottom of this section.
Function signature:
tf.nn.bidirectional_dynamic_rnn(
    cell_fw,
    cell_bw,
    inputs,
    sequence_length=None,
    initial_state_fw=None,
    initial_state_bw=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
The return value has two parts: a tuple (outputs, output_state),
i.e. two nested tuples: ((fw_outputs, bw_outputs), (fw_final_state, bw_final_state)).
Output shapes:
output, final_state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, time_major=True)
# If you only need the outputs, not the state:
final_output = tf.concat(output, 2)

# Structurally, the return value unpacks as:
output, final_state = (
    (fw_outputs, bw_outputs),
    (fw_final_state, bw_final_state)
) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, time_major=True)

'''
Pseudocode for the shapes (the outputs depend on the time_major setting):
if time_major == False (default):
    outputs is a tuple (output_fw, output_bw) with shapes
        ([batch_size, max_time, cell_fw.output_size],
         [batch_size, max_time, cell_bw.output_size])
else:
    outputs is a tuple (output_fw, output_bw) with shapes
        ([max_time, batch_size, cell_fw.output_size],
         [max_time, batch_size, cell_bw.output_size])

output_state is a tuple (output_state_fw, output_state_bw) regardless of time_major:
    output_state_fw has shape [batch_size, cell_fw.state_size]
    output_state_bw has shape [batch_size, cell_bw.state_size]
'''
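A small sketch that verifies these shapes at graph-construction time (the cell sizes are arbitrary, and deliberately different so the fw/bw shapes are distinguishable):

import tensorflow as tf

inputs = tf.random_normal(shape=[3, 4, 6], dtype=tf.float32)  # [max_time, batch_size, depth]
cell_fw = tf.nn.rnn_cell.BasicLSTMCell(10)
cell_bw = tf.nn.rnn_cell.BasicLSTMCell(12)

(out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32, time_major=True)

print(out_fw.shape)      # (3, 4, 10) -- [max_time, batch_size, cell_fw.output_size]
print(out_bw.shape)      # (3, 4, 12) -- [max_time, batch_size, cell_bw.output_size]
print(state_fw.c.shape)  # (4, 10)    -- [batch_size, cell_fw num_units]
print(state_bw.h.shape)  # (4, 12)    -- [batch_size, cell_bw num_units]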
Test code for bidirectional_dynamic_rnn and its output:
import tensorflow as tf
batch_size = 4
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
fw_cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
bw_cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)

# bidirectional dynamic rnn
output, final_state = tf.nn.bidirectional_dynamic_rnn(
    fw_cell, bw_cell, inputs, dtype=tf.float32, time_major=True)
# With time_major=True the first dimension of the input holds the RNN steps;
# recommended, since it runs a little faster.
# With time_major=False the second dimension holds the steps.
# With time_major=True each output has shape [max_time, batch_size, depth];
# otherwise [batch_size, max_time, depth] -- the same layout as the input.
# final_state holds each LSTM's final state, containing c and h;
# both have shape [batch_size, n_hidden].

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # print(sess.run(output))
    # print(sess.run(final_state))
    print(sess.run([output, final_state]))
Comparing the outputs:
As the output shows, output contains the forward and backward hidden-layer values at every time step; if layers are stacked, each time step of this output becomes the input to the corresponding time step of the next layer;
Note: the GRU and LSTM state outputs differ, but both are tuples;
If the cell is an LSTM:
For a single-layer RNN, state contains the forward and backward C and H, taken at each direction's final step; the forward H equals the last time step of output_fw, and the backward H equals the first time step of output_bw (the backward RNN reads the sequence in reverse), as the printed output below confirms;
For a multi-layer RNN, state contains the forward and backward C and H of every layer, each taken at that layer's final step; only the last layer's H values appear in output (output keeps only the last layer's outputs, while state keeps H and C for all layers);
If the cell is a GRU:
For a single-layer RNN, state contains the forward and backward H, taken at each direction's final step, matching output as described above for the LSTM;
For a multi-layer RNN, state contains the forward and backward H of every layer, each taken at that layer's final step; only the last layer's H values appear in output (output keeps only the last layer's outputs, while state keeps H for all layers); the sketch below makes the multi-layer structure concrete.
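A minimal sketch of the multi-layer bidirectional case using MultiRNNCell (the layer count and sizes are arbitrary):

import tensorflow as tf

num_layers = 2
inputs = tf.random_normal(shape=[3, 4, 6], dtype=tf.float32)

def make_cell():
    # A fresh cell per layer; reusing one instance across layers is an error in TF1
    return tf.nn.rnn_cell.BasicLSTMCell(10)

fw_cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(num_layers)])
bw_cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(num_layers)])

outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
    fw_cell, bw_cell, inputs, dtype=tf.float32, time_major=True)

# fw_state is a tuple of num_layers LSTMStateTuples, one (c, h) per layer,
# while outputs holds only the last layer's forward/backward outputs.
print(len(fw_state))        # 2
print(fw_state[0].c.shape)  # (4, 10)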
Test output:
[(array([[[ 0.14106834, -0.04528632, -0.02010263, -0.04126427,
-0.0332197 , -0.00106247, 0.13835369, -0.09588441, 0.03510131, 0.05169246], [ 0.06133147, -0.0232582 , -0.05024356, 0.09498103, 0.0086002 , -0.10530724, 0.18253753, -0.07942553, 0.07153056, 0.00886667], [ 0.14753884, -0.04079055, -0.01406685, 0.13480215, 0.03796537, -0.09043769, 0.2314299 , -0.08740149, 0.06912782, -0.01267859], [ 0.06526165, 0.01146618, -0.23638128, -0.05581092, -0.11822883, -0.16410625, -0.05027798, -0.08647367, -0.08076127, 0.04061119]], [[-0.08403834, 0.00833149, 0.03032847, 0.14006507, -0.04024323, 0.17318085, -0.13764451, -0.10830184, 0.12852332, -0.07746515], [-0.09468611, -0.01417435, -0.10501526, 0.18880928, -0.01667678, -0.17388701, -0.08902372, -0.12834042, 0.08464789, -0.06793084], [ 0.19070639, -0.01584721, -0.01068342, 0.02016271, 0.07228621, -0.10173664, 0.06089543, -0.02589199, -0.10167402, 0.0052557 ], [ 0.06274109, -0.00049266, -0.08755676, 0.09687787, -0.08074448, -0.25040433, -0.01494961, -0.19038169, -0.01472963, -0.03207526]], [[ 0.08066446, -0.08768935, 0.03135343, 0.096322 , -0.03832811, 0.11857887, 0.03208668, -0.2740993 , 0.15332787, -0.1054797 ], [-0.12622547, -0.01234146, -0.053387 , 0.17855036, 0.0279602 , -0.16127944, -0.12577067, -0.10588194, 0.1634436 , -0.05892743], [ 0.17111342, -0.01833202, -0.1570774 , 0.00630365, -0.06089298, -0.23136179, -0.08443928, -0.38194925, -0.20987682, -0.02470726], [ 0.02754376, -0.03641582, 0.06693614, 0.28510216, 0.01096512, -0.19997968, 0.01686571, -0.18401314, 0.13415188, -0.11835594]]], dtype=float32),
array([[[-4.71892580e-02, 1.24903403e-01, -1.30026698e-01, 1.35595098e-01, 1.79924071e-01, -2.22253352e-02, 8.44292119e-02, -3.16204876e-02, 1.23763248e-01, -1.41852632e-01], [ 1.13655411e-01, 2.14847624e-01, 1.42422020e-01, -2.18600005e-01, 1.82050109e-01, -1.60371423e-01, -1.40698269e-01, -1.89826444e-01, -1.82936639e-01, -1.77186489e-01], [-1.81219578e-01, -1.18457131e-01, -1.19064420e-01, 1.67499259e-01, 5.41930646e-03, 5.22245020e-02, -1.44038424e-01, -1.43634990e-01, 5.55445813e-02, -1.90527022e-01], [-7.35785663e-02, -7.54526583e-03, 7.17180362e-03, -2.84532481e-03, -9.26728696e-02, -1.73879474e-01, -1.37960657e-01, -1.92156255e-01, -6.91511333e-02, -1.93823442e-01]], [[ 2.35215947e-02, 2.12839022e-01, 1.19708128e-01, 9.60118312e-04, 5.18597253e-02, -1.06155628e-03, -1.04444381e-02, 4.71496470e-02, 9.75937350e-04, -5.14771827e-02], [ 4.95758429e-02, 2.43057415e-01, 2.12201729e-01, -1.59496874e-01, 2.85923164e-02, -1.15410425e-01, -9.19163823e-02, -5.79124615e-02, -1.70543984e-01, -7.04976842e-02], [-3.71870756e-01, -3.13842781e-02, -2.12511465e-01, 4.09433812e-01, -4.14451808e-01, -4.04715101e-04, -4.27415073e-02, 3.08825504e-02, 2.47091308e-01, -1.90000907e-01], [ 1.59666702e-01, -9.14119035e-02, 2.77409274e-02, -3.35936815e-01, 9.16254967e-02, -2.70533301e-02, -1.54404104e-01, -6.44175261e-02, -1.75929859e-01, -2.02384084e-01]], [[-4.12032045e-02, 4.64685149e-02, -1.03556208e-01, 1.14612505e-01, 1.44457981e-01, 4.20663096e-02, 4.07745019e-02, -1.11312367e-01, 5.78245446e-02, -8.26596469e-02], [ 7.64411911e-02, 9.69330519e-02, 8.21158811e-02, -1.07323572e-01, 9.91004631e-02, -5.17336503e-02, -4.19655181e-02, -1.74307302e-02, -5.12403548e-02, -4.96065766e-02], [-2.42707521e-01, 1.06276445e-01, -6.91626742e-02, 4.17545497e-01, -2.72490084e-01, -1.08139351e-01, -2.92627905e-02, -1.17304660e-01, 2.38732815e-01, -2.74660379e-01], [ 1.69804603e-01, -1.16702899e-01, 2.23863665e-02, -2.18844891e-01, 7.04917312e-02, 8.52173045e-02, -1.40099138e-01, 5.31176142e-02, -8.26682746e-02, -9.91851762e-02]]], dtype=float32)),
(LSTMStateTuple(c=array([[ 0.15401003, -0.1852682 , 0.07581142, 0.2094832 , -0.10182755, 0.25072512, 0.05835846, -0.44341922, 0.30883613, -0.22787577], [-0.24870059, -0.02492736, -0.11864171, 0.35445976, 0.06311206, -0.30257928, -0.21532504, -0.22805184, 0.35281387, -0.13972323], [ 0.26539534, -0.04326108, -0.21563095, 0.02033684, -0.22373867, -0.72543144, -0.23284285, -0.6351244 , -0.38619697, -0.07513228], [ 0.06854046, -0.08422969, 0.19689892, 0.5770679 , 0.02262445, -0.48846442, 0.04160193, -0.38153723, 0.2185685 , -0.40132868]], dtype=float32), h=array([[ 0.08066446, -0.08768935, 0.03135343, 0.096322 , -0.03832811, 0.11857887, 0.03208668, -0.2740993 , 0.15332787, -0.1054797 ], [-0.12622547, -0.01234146, -0.053387 , 0.17855036, 0.0279602 , -0.16127944, -0.12577067, -0.10588194, 0.1634436 , -0.05892743], [ 0.17111342, -0.01833202, -0.1570774 , 0.00630365, -0.06089298, -0.23136179, -0.08443928, -0.38194925, -0.20987682, -0.02470726], [ 0.02754376, -0.03641582, 0.06693614, 0.28510216, 0.01096512, -0.19997968, 0.01686571, -0.18401314, 0.13415188, -0.11835594]], dtype=float32)), LSTMStateTuple(c=array([[-0.0744614 , 0.26128462, -0.2831493 , 0.26449782, 0.31356773, -0.05635487, 0.18961594, -0.06494538, 0.32900015, -0.21728653], [ 0.27273336, 0.3635042 , 0.32611233, -0.40245444, 0.40640068, -0.3703317 , -0.35593987, -0.3990611 , -0.40341747, -0.39981088], [-0.34268516, -0.20780411, -0.30328298, 0.24868692, 0.00983843, 0.1082218 , -0.2859917 , -0.24331218, 0.13357104, -0.36969435], [-0.19263096, -0.01261891, 0.01626399, -0.00499403, -0.22496712, -0.32415336, -0.4642058 , -0.39442265, -0.10230929, -0.42570713]], dtype=float32), h=array([[-0.04718926, 0.1249034 , -0.1300267 , 0.1355951 , 0.17992407, -0.02222534, 0.08442921, -0.03162049, 0.12376325, -0.14185263], [ 0.11365541, 0.21484762, 0.14242202, -0.2186 , 0.18205011, -0.16037142, -0.14069827, -0.18982644, -0.18293664, -0.17718649], [-0.18121958, -0.11845713, -0.11906442, 0.16749926, 0.00541931, 0.0522245 , -0.14403842, -0.14363499, 0.05554458, -0.19052702], [-0.07357857, -0.00754527, 0.0071718 , -0.00284532, -0.09267287, -0.17387947, -0.13796066, -0.19215626, -0.06915113, -0.19382344]], dtype=float32)))]
3. Converting encoder outputs into decoder inputs
If decoder_num_layer does not equal encoder_num_layer, you can expand the encoder's state, or add a fully connected layer to generate the extra encoder_states (sketched below);
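One hedged way to bridge mismatched depths is to project the encoder's final state through one fully connected layer per decoder layer. This is a sketch of that idea under assumed names (bridge_state, decoder_units are hypothetical), not a canonical recipe:

import tensorflow as tf

# Hypothetical bridge: map a single encoder (h, c) pair onto
# decoder_num_layers LSTM states via dense projections.
def bridge_state(encoder_h, encoder_c, decoder_num_layers, decoder_units):
    decoder_states = []
    for i in range(decoder_num_layers):
        c = tf.layers.dense(encoder_c, decoder_units, name="bridge_c_%d" % i)
        h = tf.layers.dense(encoder_h, decoder_units, name="bridge_h_%d" % i)
        decoder_states.append(tf.contrib.rnn.LSTMStateTuple(c=c, h=h))
    return tuple(decoder_states)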
Note: in an encoder-decoder model, when a unidirectional RNN encoder feeds the decoder, state is passed in directly with no conversion; you must then ensure encoder_num_layer == decoder_num_layer.
Note: in an encoder-decoder model, when a bidirectional RNN encoder feeds the decoder,
the forward and backward parts must be combined.
Example code for converting the bidirectional encoder's result into decoder input:
output, final_state = (
    (encoder_fw_outputs, encoder_bw_outputs),
    (encoder_fw_final_state, encoder_bw_final_state)
) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, time_major=True)
# When the encoder's result feeds the decoder, only the encoder's last
# hidden layer (the state) is needed.

# The bidirectional encoder's outputs, concatenated along the feature axis:
encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
# encoder_outputs shape: [..., cell_fw.output_size + cell_bw.output_size]

'''
This snippet cannot be used as-is (single-layer case; encoder and decoder are LSTMs):
# the encoder_state to pass into the decoder
encoder_final_state_c = tf.concat(
    (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
encoder_final_state_h = tf.concat(
    (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
# encoder_final_state is a tuple with shapes:
# ([batch_size, cell_fw.state_size + cell_bw.state_size],
#  [batch_size, cell_fw.state_size + cell_bw.state_size])
encoder_final_state = tf.contrib.rnn.LSTMStateTuple(
    c=encoder_final_state_c,
    h=encoder_final_state_h
)
'''

# The code below can be used directly; num_layers is the encoder's layer count.
# The layer counts must satisfy:
#   unidirectional encoder: encoder_layers_num == decoder_layers_num
#   bidirectional encoder:  encoder_layers_num == decoder_layers_num // 2
encoder_state = []
for i in range(num_layers):
    encoder_state.append(final_state[0][i])  # forward
    encoder_state.append(final_state[1][i])  # backward
encoder_state = tuple(encoder_state)  # 2*num_layers LSTMStateTuples of (c, h), each [batch_size, hidden_size]

return encoder_outputs, encoder_state
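For context, a hedged sketch of how the resulting encoder_state might seed a decoder; the decoder construction here is illustrative and assumes each encoder direction's hidden size equals the decoder's hidden_size:

import tensorflow as tf

num_layers = 2
hidden_size = 10

# The decoder depth must be 2 * num_layers to match the interleaved fw/bw states
decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(hidden_size) for _ in range(2 * num_layers)])

# decoder_inputs is assumed: [max_time, batch_size, input_size], time-major
decoder_inputs = tf.random_normal(shape=[5, 4, 6], dtype=tf.float32)
decoder_outputs, decoder_state = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs, initial_state=encoder_state, time_major=True)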