Environment: Python 3.6, TensorFlow 1.11
Test code:
The test runs TensorFlow in eager mode, which makes debugging easier and lets us inspect intermediate results.
import tensorflow as tf

tf.enable_eager_execution()

batch_size = 4
input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
seq_length = tf.constant([2, 3, 2, 3], dtype=tf.int32)
import pdb; pdb.set_trace()
output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state,
                                        sequence_length=seq_length, time_major=True)
# If time_major is True, the first dimension of the input holds the RNN steps;
# this is recommended since it runs a bit faster.
# If False, the second dimension of the input holds the steps.
# If True, output has shape [steps, batch_size, depth]; otherwise
# [batch_size, max_time, depth], i.e. the same layout as the input.
# final_state is the final state of the whole LSTM; it contains c and h,
# each of shape [batch_size, n_hidden].
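Since eager execution is enabled, every tensor above can be inspected directly at the pdb prompt. The checks below are only a sketch of what to look for; the kernel-shape comment assumes BasicLSTMCell's usual [depth + num_units, 4 * num_units] weight layout and that the kernel is the first entry of cell.variables.

print(output.shape)             # (3, 4, 10): [max_time, batch_size, num_units], time-major like the input
print(final_state.c.shape)      # (4, 10)
print(final_state.h.shape)      # (4, 10)
print(cell.variables[0].shape)  # expected (16, 40): kernel of shape [depth + num_units, 4 * num_units]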
tf.nn.dynamic_rnn is defined in tensorflow/python/ops/rnn.py; step into it to debug:
1 @tf_export("nn.dynamic_rnn") 2 def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 3 dtype=None, parallel_iterations=None, swap_memory=False, 4 time_major=False, scope=None): 5 """Creates a recurrent neural network specified by RNNCell `cell`. 6 7 Performs fully dynamic unrolling of `inputs`. 8 9 Example: 10 11 ```python 12 # create a BasicRNNCell 13 rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) 14 15 # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 16 17 # defining initial state 18 initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 19 20 # 'state' is a tensor of shape [batch_size, cell_state_size] 21 outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data, 22 initial_state=initial_state, 23 dtype=tf.float32) 24 ``` 25 26 ```python 27 # create 2 LSTMCells 28 rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 29 30 # create a RNN cell composed sequentially of a number of RNNCells 31 multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) 32 33 # 'outputs' is a tensor of shape [batch_size, max_time, 256] 34 # 'state' is a N-tuple where N is the number of LSTMCells containing a 35 # tf.contrib.rnn.LSTMStateTuple for each cell 36 outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell, 37 inputs=data, 38 dtype=tf.float32) 39 ``` 40 41 42 Args: 43 cell: An instance of RNNCell. 44 inputs: The RNN inputs. 45 If `time_major == False` (default), this must be a `Tensor` of shape: 46 `[batch_size, max_time, ...]`, or a nested tuple of such 47 elements. 48 If `time_major == True`, this must be a `Tensor` of shape: 49 `[max_time, batch_size, ...]`, or a nested tuple of such 50 elements. 51 This may also be a (possibly nested) tuple of Tensors satisfying 52 this property. The first two dimensions must match across all the inputs, 53 but otherwise the ranks and other shape components may differ. 54 In this case, input to `cell` at each time-step will replicate the 55 structure of these tuples, except for the time dimension (from which the 56 time is taken). 57 The input to `cell` at each time step will be a `Tensor` or (possibly 58 nested) tuple of Tensors each with dimensions `[batch_size, ...]`. 59 sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 60 Used to copy-through state and zero-out outputs when past a batch 61 element's sequence length. So it's more for performance than correctness. 62 initial_state: (optional) An initial state for the RNN. 63 If `cell.state_size` is an integer, this must be 64 a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 65 If `cell.state_size` is a tuple, this should be a tuple of 66 tensors having shapes `[batch_size, s] for s in cell.state_size`. 67 dtype: (optional) The data type for the initial state and expected output. 68 Required if initial_state is not provided or RNN state has a heterogeneous 69 dtype. 70 parallel_iterations: (Default: 32). The number of iterations to run in 71 parallel. Those operations which do not have any temporal dependency 72 and can be run in parallel, will be. This parameter trades off 73 time for space. Values >> 1 use more memory but take less time, 74 while smaller values use less memory but computations take longer. 75 swap_memory: Transparently swap the tensors produced in forward inference 76 but needed for back prop from GPU to CPU. This allows training RNNs 77 which would typically not fit on a single GPU, with very minimal (or no) 78 performance penalty. 
79 time_major: The shape format of the `inputs` and `outputs` Tensors. 80 If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 81 If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 82 Using `time_major = True` is a bit more efficient because it avoids 83 transposes at the beginning and end of the RNN calculation. However, 84 most TensorFlow data is batch-major, so by default this function 85 accepts input and emits output in batch-major form. 86 scope: VariableScope for the created subgraph; defaults to "rnn". 87 88 Returns: 89 A pair (outputs, state) where: 90 91 outputs: The RNN output `Tensor`. 92 93 If time_major == False (default), this will be a `Tensor` shaped: 94 `[batch_size, max_time, cell.output_size]`. 95 96 If time_major == True, this will be a `Tensor` shaped: 97 `[max_time, batch_size, cell.output_size]`. 98 99 Note, if `cell.output_size` is a (possibly nested) tuple of integers 100 or `TensorShape` objects, then `outputs` will be a tuple having the 101 same structure as `cell.output_size`, containing Tensors having shapes 102 corresponding to the shape data in `cell.output_size`. 103 104 state: The final state. If `cell.state_size` is an int, this 105 will be shaped `[batch_size, cell.state_size]`. If it is a 106 `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 107 If it is a (possibly nested) tuple of ints or `TensorShape`, this will 108 be a tuple having the corresponding shapes. If cells are `LSTMCells` 109 `state` will be a tuple containing a `LSTMStateTuple` for each cell. 110 111 Raises: 112 TypeError: If `cell` is not an instance of RNNCell. 113 ValueError: If inputs is None or an empty list. 114 """ 115 rnn_cell_impl.assert_like_rnncell("cell", cell) 116 117 with vs.variable_scope(scope or "rnn") as varscope: 118 # Create a new scope in which the caching device is either 119 # determined by the parent scope, or is set to place the cached 120 # Variable using the same placement as for the rest of the RNN. 121 if _should_cache(): 122 if varscope.caching_device is None: 123 varscope.set_caching_device(lambda op: op.device) 124 125 # By default, time_major==False and inputs are batch-major: shaped 126 # [batch, time, depth] 127 # For internal calculations, we transpose to [time, batch, depth] 128 flat_input = nest.flatten(inputs) 129 130 if not time_major: 131 # (B,T,D) => (T,B,D) 132 flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 133 flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 134 135 parallel_iterations = parallel_iterations or 32 136 if sequence_length is not None: 137 sequence_length = math_ops.to_int32(sequence_length) 138 if sequence_length.get_shape().ndims not in (None, 1): 139 raise ValueError( 140 "sequence_length must be a vector of length batch_size, " 141 "but saw shape: %s" % sequence_length.get_shape()) 142 sequence_length = array_ops.identity( # Just to find it in the graph. 
143 sequence_length, name="sequence_length") 144 145 batch_size = _best_effort_input_batch_size(flat_input) 146 147 if initial_state is not None: 148 state = initial_state 149 else: 150 if not dtype: 151 raise ValueError("If there is no initial_state, you must give a dtype.") 152 if getattr(cell, "get_initial_state", None) is not None: 153 state = cell.get_initial_state( 154 inputs=None, batch_size=batch_size, dtype=dtype) 155 else: 156 state = cell.zero_state(batch_size, dtype) 157 158 def _assert_has_shape(x, shape): 159 x_shape = array_ops.shape(x) 160 packed_shape = array_ops.stack(shape) 161 return control_flow_ops.Assert( 162 math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), 163 ["Expected shape for Tensor %s is " % x.name, 164 packed_shape, " but saw shape: ", x_shape]) 165 166 if not context.executing_eagerly() and sequence_length is not None: 167 # Perform some shape validation 168 with ops.control_dependencies( 169 [_assert_has_shape(sequence_length, [batch_size])]): 170 sequence_length = array_ops.identity( 171 sequence_length, name="CheckSeqLen") 172 173 inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 174 175 (outputs, final_state) = _dynamic_rnn_loop( 176 cell, 177 inputs, 178 state, 179 parallel_iterations=parallel_iterations, 180 swap_memory=swap_memory, 181 sequence_length=sequence_length, 182 dtype=dtype) 183 184 # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 185 # If we are performing batch-major calculations, transpose output back 186 # to shape [batch, time, depth] 187 if not time_major: 188 # (T,B,D) => (B,T,D) 189 outputs = nest.map_structure(_transpose_batch_time, outputs) 190 191 return (outputs, final_state)
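A side note on the batch-major handling above: when time_major=False, the inputs are converted to time-major with _transpose_batch_time before the loop, which for a rank-3 tensor amounts to swapping the first two axes. A minimal illustration (not the library code itself):

# (B, T, D) => (T, B, D): swap batch and time axes, keep depth.
inputs_bt = tf.random_normal([4, 3, 6])              # [batch_size, max_time, depth]
inputs_tb = tf.transpose(inputs_bt, perm=[1, 0, 2])  # [max_time, batch_size, depth]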
Finally, dynamic_rnn calls _dynamic_rnn_loop:
  1 def _dynamic_rnn_loop(cell,
  2                       inputs,
  3                       initial_state,
  4                       parallel_iterations,
  5                       swap_memory,
  6                       sequence_length=None,
  7                       dtype=None):
  8   """Internal implementation of Dynamic RNN.
  9
 10   Args:
 11     cell: An instance of RNNCell.
 12     inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
 13       tuple of such elements.
 14     initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
 15       `cell.state_size` is a tuple, then this should be a tuple of
 16       tensors having shapes `[batch_size, s] for s in cell.state_size`.
 17     parallel_iterations: Positive Python int.
 18     swap_memory: A Python boolean
 19     sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
 20     dtype: (optional) Expected dtype of output. If not specified, inferred from
 21       initial_state.
 22
 23   Returns:
 24     Tuple `(final_outputs, final_state)`.
 25     final_outputs:
 26       A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
 27       `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
 28       objects, then this returns a (possibly nested) tuple of Tensors matching
 29       the corresponding shapes.
 30     final_state:
 31       A `Tensor`, or possibly nested tuple of Tensors, matching in length
 32       and shapes to `initial_state`.
 33   Raises:
 34     ValueError: If the input depth cannot be inferred via shape inference
 35       from the inputs.
 36   """
 37   import pdb; pdb.set_trace()
 38   state = initial_state
 39   assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
 40
 41   state_size = cell.state_size  # LSTMStateTuple(c=10, h=10)
 42
 43   flat_input = nest.flatten(inputs)  # list; flat_input[0].shape = TensorShape([Dimension(3), Dimension(4), Dimension(6)])
 44   flat_output_size = nest.flatten(cell.output_size)  # [10]
 45
 46   # Construct an initial output
 47   input_shape = array_ops.shape(flat_input[0])  # array([3, 4, 6])
 48   time_steps = input_shape[0]  # 3
 49   batch_size = _best_effort_input_batch_size(flat_input)  # 4
 50
 51   inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
 52                            for input_ in flat_input)  # (TensorShape([Dimension(3), Dimension(4), Dimension(6)]),)
 53
 54   const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]  # 3, 4
 55
 56   for shape in inputs_got_shape:
 57     if not shape[2:].is_fully_defined():
 58       raise ValueError(
 59           "Input size (depth of inputs) must be accessible via shape inference,"
 60           " but saw value None.")
 61     got_time_steps = shape[0].value  # 3
 62     got_batch_size = shape[1].value  # 4
 63     if const_time_steps != got_time_steps:
 64       raise ValueError(
 65           "Time steps is not the same for all the elements in the input in a "
 66           "batch.")
 67     if const_batch_size != got_batch_size:
 68       raise ValueError(
 69           "Batch_size is not the same for all the elements in the input.")
 70
 71   # Prepare dynamic conditional copying of state & output
 72   def _create_zero_arrays(size):
 73     size = _concat(batch_size, size)
 74     return array_ops.zeros(
 75         array_ops.stack(size), _infer_state_dtype(dtype, state))
 76
 77   flat_zero_output = tuple(_create_zero_arrays(output)
 78                            for output in flat_output_size)  # tuple; flat_zero_output[0].shape: TensorShape([Dimension(4), Dimension(10)])
 79   zero_output = nest.pack_sequence_as(structure=cell.output_size,
 80                                       flat_sequence=flat_zero_output)  # TensorShape([Dimension(4), Dimension(10)])
 81
 82   if sequence_length is not None:
 83     min_sequence_length = math_ops.reduce_min(sequence_length)  # 2
 84     max_sequence_length = math_ops.reduce_max(sequence_length)  # 3
 85   else:
 86     max_sequence_length = time_steps
 87
 88   time = array_ops.constant(0, dtype=dtypes.int32, name="time")
 89
 90   with ops.name_scope("dynamic_rnn") as scope:
 91     base_name = scope
 92
 93     def _create_ta(name, element_shape, dtype):
 94       return tensor_array_ops.TensorArray(dtype=dtype,
 95                                           size=time_steps,
 96                                           element_shape=element_shape,
 97                                           tensor_array_name=base_name + name)
 98
 99     in_graph_mode = not context.executing_eagerly()
100     if in_graph_mode:
101       output_ta = tuple(
102           _create_ta(
103               "output_%d" % i,
104               element_shape=(tensor_shape.TensorShape([const_batch_size])
105                              .concatenate(
106                                  _maybe_tensor_shape_from_tensor(out_size))),
107               dtype=_infer_state_dtype(dtype, state))
108           for i, out_size in enumerate(flat_output_size))
109       input_ta = tuple(
110           _create_ta(
111               "input_%d" % i,
112               element_shape=flat_input_i.shape[1:],
113               dtype=flat_input_i.dtype)
114           for i, flat_input_i in enumerate(flat_input))
115       input_ta = tuple(ta.unstack(input_)
116                        for ta, input_ in zip(input_ta, flat_input))
117     else:
118       output_ta = tuple([0 for _ in range(time_steps.numpy())]
119                         for i in range(len(flat_output_size)))  # ([0, 0, 0],)
120       input_ta = flat_input  # list; input_ta[0].shape = TensorShape([Dimension(3), Dimension(4), Dimension(6)])
121
122     def _time_step(time, output_ta_t, state):
123       """Take a time step of the dynamic RNN.
124
125       Args:
126         time: int32 scalar Tensor.
127         output_ta_t: List of `TensorArray`s that represent the output.
128         state: nested tuple of vector tensors that represent the state.
129
130       Returns:
131         The tuple (time + 1, output_ta_t with updated flow, new_state).
132       """
133       import pdb; pdb.set_trace()
134       if in_graph_mode:
135         input_t = tuple(ta.read(time) for ta in input_ta)
136         # Restore some shape information
137         for input_, shape in zip(input_t, inputs_got_shape):
138           input_.set_shape(shape[1:])
139       else:
140         input_t = tuple(ta[time.numpy()] for ta in input_ta)  # TensorShape([Dimension(4), Dimension(6)])
141
142       input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)  # TensorShape([Dimension(4), Dimension(6)])
143       # Keras RNN cells only accept state as list, even if it's a single tensor.
144       is_keras_rnn_cell = _is_keras_rnn_cell(cell)
145       if is_keras_rnn_cell and not nest.is_sequence(state):
146         state = [state]
147       call_cell = lambda: cell(input_t, state)
148
149       if sequence_length is not None:
150         (output, new_state) = _rnn_step(
151             time=time,
152             sequence_length=sequence_length,
153             min_sequence_length=min_sequence_length,
154             max_sequence_length=max_sequence_length,
155             zero_output=zero_output,
156             state=state,
157             call_cell=call_cell,
158             state_size=state_size,
159             skip_conditionals=True)
160       else:
161         (output, new_state) = call_cell()
162
163       # Keras cells always wrap state as list, even if it's a single tensor.
164       if is_keras_rnn_cell and len(new_state) == 1:
165         new_state = new_state[0]
166       # Pack state if using state tuples
167       output = nest.flatten(output)
168
169       if in_graph_mode:
170         output_ta_t = tuple(
171             ta.write(time, out) for ta, out in zip(output_ta_t, output))
172       else:
173         for ta, out in zip(output_ta_t, output):
174           ta[time.numpy()] = out
175
176       return (time + 1, output_ta_t, new_state)
177
178     if in_graph_mode:
179       # Make sure that we run at least 1 step, if necessary, to ensure
180       # the TensorArrays pick up the dynamic shape.
181       loop_bound = math_ops.minimum(
182           time_steps, math_ops.maximum(1, max_sequence_length))
183     else:
184       # Using max_sequence_length isn't currently supported in the Eager branch.
185       loop_bound = time_steps  # 3
186
187     _, output_final_ta, final_state = control_flow_ops.while_loop(
188         cond=lambda time, *_: time < loop_bound,
189         body=_time_step,
190         loop_vars=(time, output_ta, state),
191         parallel_iterations=parallel_iterations,
192         maximum_iterations=time_steps,
193         swap_memory=swap_memory)
194
195     # Unpack final output if not using output tuples.
196     if in_graph_mode:
197       final_outputs = tuple(ta.stack() for ta in output_final_ta)
198       # Restore some shape information
199       for output, output_size in zip(final_outputs, flat_output_size):
200         shape = _concat(
201             [const_time_steps, const_batch_size], output_size, static=True)
202         output.set_shape(shape)
203     else:
204       final_outputs = output_final_ta
205
206     final_outputs = nest.pack_sequence_as(
207         structure=cell.output_size, flat_sequence=final_outputs)
208     if not in_graph_mode:
209       final_outputs = nest.map_structure_up_to(
210           cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
211
212     return (final_outputs, final_state)
As we can see, dynamic_rnn relies mainly on while_loop to handle the fact that different elements of a batch can have different sequence lengths.
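To make the structure of that loop concrete, here is a heavily simplified, graph-mode sketch of what _dynamic_rnn_loop builds: a time-major input unstacked into a TensorArray, a body that runs the cell once per step and writes the output back, and tf.while_loop driving it. It omits the sequence_length/_rnn_step handling and all the shape bookkeeping, so treat it as an illustration under those assumptions rather than the actual implementation.

import tensorflow as tf

max_time, batch_size, depth, num_units = 3, 4, 6, 10

inputs = tf.random_normal([max_time, batch_size, depth])   # time-major input
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
state = cell.zero_state(batch_size, tf.float32)

input_ta = tf.TensorArray(tf.float32, size=max_time).unstack(inputs)
output_ta = tf.TensorArray(tf.float32, size=max_time)
time = tf.constant(0, dtype=tf.int32)

def _time_step(time, output_ta_t, state):
    x_t = input_ta.read(time)                 # [batch_size, depth]
    x_t.set_shape([batch_size, depth])        # restore static shape lost by read()
    output, new_state = cell(x_t, state)      # one cell step
    output_ta_t = output_ta_t.write(time, output)
    return time + 1, output_ta_t, new_state

_, final_output_ta, final_state = tf.while_loop(
    cond=lambda time, *_: time < max_time,
    body=_time_step,
    loop_vars=(time, output_ta, state))

outputs = final_output_ta.stack()             # [max_time, batch_size, num_units]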
Lines 82-86 of _dynamic_rnn_loop above show that if sequence_length is not passed, max_sequence_length simply falls back to time_steps (input.shape[0]) and every step runs the cell unconditionally. When sequence_length is given, _rnn_step is called instead: for batch elements whose length has already been exceeded, the output is set to 0 and the state is copied through unchanged. This is implemented at lines 60 and 70 of the code below (a quick eager check of this behaviour follows the _rnn_step source).
  1 def _rnn_step(
  2     time, sequence_length, min_sequence_length, max_sequence_length,
  3     zero_output, state, call_cell, state_size, skip_conditionals=False):
  4   """Calculate one step of a dynamic RNN minibatch.
  5
  6   Returns an (output, state) pair conditioned on `sequence_length`.
  7   When skip_conditionals=False, the pseudocode is something like:
  8
  9   if t >= max_sequence_length:
 10     return (zero_output, state)
 11   if t < min_sequence_length:
 12     return call_cell()
 13
 14   # Selectively output zeros or output, old state or new state depending
 15   # on whether we've finished calculating each row.
 16   new_output, new_state = call_cell()
 17   final_output = np.vstack([
 18     zero_output if time >= sequence_length[r] else new_output_r
 19     for r, new_output_r in enumerate(new_output)
 20   ])
 21   final_state = np.vstack([
 22     state[r] if time >= sequence_length[r] else new_state_r
 23     for r, new_state_r in enumerate(new_state)
 24   ])
 25   return (final_output, final_state)
 26
 27   Args:
 28     time: int32 `Tensor` scalar.
 29     sequence_length: int32 `Tensor` vector of size [batch_size].
 30     min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
 31     max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
 32     zero_output: `Tensor` vector of shape [output_size].
 33     state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
 34       or a list/tuple of such tensors.
 35     call_cell: lambda returning tuple of (new_output, new_state) where
 36       new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
 37       new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
 38     state_size: The `cell.state_size` associated with the state.
 39     skip_conditionals: Python bool, whether to skip using the conditional
 40       calculations.  This is useful for `dynamic_rnn`, where the input tensor
 41       matches `max_sequence_length`, and using conditionals just slows
 42       everything down.
 43
 44   Returns:
 45     A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
 46       final_output is a `Tensor` matrix of shape [batch_size, output_size]
 47       final_state is either a single `Tensor` matrix, or a tuple of such
 48         matrices (matching length and shapes of input `state`).
 49
 50   Raises:
 51     ValueError: If the cell returns a state tuple whose length does not match
 52       that returned by `state_size`.
 53   """
 54   import pdb; pdb.set_trace()
 55   # Convert state to a list for ease of use
 56   flat_state = nest.flatten(state)  # [c, h], each of shape [4, 10]
 57   flat_zero_output = nest.flatten(zero_output)  # list; flat_zero_output[0].shape: TensorShape([Dimension(4), Dimension(10)])
 58
 59   # Vector describing which batch entries are finished.
 60   copy_cond = time >= sequence_length  # step 1: array([False, False, False, False])
 61
 62   def _copy_one_through(output, new_output):
 63     # TensorArray and scalar get passed through.
 64     if isinstance(output, tensor_array_ops.TensorArray):
 65       return new_output
 66     if output.shape.ndims == 0:
 67       return new_output
 68     # Otherwise propagate the old or the new value.
 69     with ops.colocate_with(new_output):
 70       return array_ops.where(copy_cond, output, new_output)  # rows past seq_length keep `output`: zeros for outputs, old values for states
 71
 72   def _copy_some_through(flat_new_output, flat_new_state):
 73     # Use broadcasting select to determine which values should get
 74     # the previous state & zero output, and which values should get
 75     # a calculated state & output.
 76     flat_new_output = [
 77         _copy_one_through(zero_output, new_output)
 78         for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
 79     flat_new_state = [
 80         _copy_one_through(state, new_state)
 81         for state, new_state in zip(flat_state, flat_new_state)]
 82     return flat_new_output + flat_new_state
 83
 84   def _maybe_copy_some_through():
 85     """Run RNN step.  Pass through either no or some past state."""
 86     new_output, new_state = call_cell()
 87
 88     nest.assert_same_structure(state, new_state)
 89
 90     flat_new_state = nest.flatten(new_state)
 91     flat_new_output = nest.flatten(new_output)
 92     return control_flow_ops.cond(
 93         # if t < min_seq_len: calculate and return everything
 94         time < min_sequence_length, lambda: flat_new_output + flat_new_state,
 95         # else copy some of it through
 96         lambda: _copy_some_through(flat_new_output, flat_new_state))
 97
 98   # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
 99   # but benefits from removing cond() and its gradient.  We should
100   # profile with and without this switch here.
101   if skip_conditionals:
102     # Instead of using conditionals, perform the selective copy at all time
103     # steps.  This is faster when max_seq_len is equal to the number of unrolls
104     # (which is typical for dynamic_rnn).
105     new_output, new_state = call_cell()
106     nest.assert_same_structure(state, new_state)
107     new_state = nest.flatten(new_state)  # [c, h], each of shape (4, 10)
108     new_output = nest.flatten(new_output)  # shape (4, 10)
109     final_output_and_state = _copy_some_through(new_output, new_state)
110   else:
111     empty_update = lambda: flat_zero_output + flat_state
112     final_output_and_state = control_flow_ops.cond(
113         # if t >= max_seq_len: copy all state through, output zeros
114         time >= max_sequence_length, empty_update,
115         # otherwise calculation is required: copy some or all of it through
116         _maybe_copy_some_through)
117
118   if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
119     raise ValueError("Internal error: state and output were not concatenated "
120                      "correctly.")
121   final_output = final_output_and_state[:len(flat_zero_output)]
122   final_state = final_output_and_state[len(flat_zero_output):]
123
124   for output, flat_output in zip(final_output, flat_zero_output):
125     output.set_shape(flat_output.get_shape())
126   for substate, flat_substate in zip(final_state, flat_state):
127     if not isinstance(substate, tensor_array_ops.TensorArray):
128       substate.set_shape(flat_substate.get_shape())
129
130   final_output = nest.pack_sequence_as(
131       structure=zero_output, flat_sequence=final_output)
132   final_state = nest.pack_sequence_as(
133       structure=state, flat_sequence=final_state)
134
135   return final_output, final_state
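Back in the eager test at the top of the post (seq_length = [2, 3, 2, 3], time_major=True, output shape [3, 4, 10]), this copy-through logic can be checked directly. The exact numbers depend on the random input, so only the zero/copy pattern below is deterministic:

# Batch elements 0 and 2 have length 2, so their step-3 output is zeroed out:
print(tf.reduce_all(tf.equal(output[2, 0, :], 0.0)))   # True
print(tf.reduce_all(tf.equal(output[2, 1, :], 0.0)))   # False: element 1 really has 3 steps
# Their final state is copied through from the last valid step; for an LSTM the
# output equals h, so final_state.h[0] matches the step-2 output of element 0:
print(tf.reduce_all(tf.equal(final_state.h[0], output[1, 0, :])))  # True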