Notes on the new seq2seq API


attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=FLAGS.rnn_hidden_size,
    memory=encoder_outputs,
    memory_sequence_length=encoder_sequence_length)

This step creates an attention_mechanism. It is invoked through __call__(self, query, previous_alignments): query is the decoder hidden state, previous_alignments holds the alignment weights from the previous decode step (the encoder hiddens were already supplied as memory at construction time), and the output is a probability distribution (the alignments) over the memory time steps.
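To make the shapes concrete, here is a minimal sketch of querying the mechanism directly, with made-up batch/time/unit sizes; initial_alignments is used only to get a correctly shaped previous_alignments tensor, and the keyword name follows the contrib signature quoted above:

import tensorflow as tf

batch_size, max_time, num_units = 4, 7, 128                      # assumed toy sizes
encoder_outputs = tf.random_normal([batch_size, max_time, num_units])
encoder_sequence_length = tf.fill([batch_size], max_time)

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=encoder_sequence_length)

decoder_hidden = tf.random_normal([batch_size, num_units])        # the query
prev_alignments = attention_mechanism.initial_alignments(batch_size, tf.float32)

# alignments: [batch_size, max_time], a probability distribution over memory steps
alignments = attention_mechanism(decoder_hidden,
                                 previous_alignments=prev_alignments)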

 

helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
    inputs, tf.to_int32(sequence_length), emb,
    tf.constant(FLAGS.scheduled_sampling_probability))

This creates a helper, which handles the input and output at each decoding time step.
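Concretely, ScheduledEmbeddingTrainingHelper implements scheduled sampling: at each step, with probability scheduled_sampling_probability the next input is the embedding of an id sampled from the model's own output instead of the ground-truth input. A rough conceptual sketch of that per-step choice (my own illustration with made-up names, not the helper's real code):

import tensorflow as tf

def next_decoder_input(ground_truth_emb, logits, emb, p):
  """ground_truth_emb: [batch, emb_dim], logits: [batch, vocab], emb: [vocab, emb_dim]."""
  sampled_id = tf.squeeze(tf.multinomial(logits, 1), axis=1)      # [batch]
  sampled_emb = tf.nn.embedding_lookup(emb, sampled_id)           # [batch, emb_dim]
  use_sample = tf.random_uniform([tf.shape(logits)[0]]) < p       # per-example coin flip
  return tf.where(use_sample, sampled_emb, ground_truth_emb)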

 

my_decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell, helper=helper, initial_state=state)

This is the core of the decoding call. Each decode step is driven by def step(self, time, inputs, state, name=None).

First the inputs and the attention vector are concatenated and used as the cell input. (Why is this valid? Look at the LSTM implementation: W1*U + W2*V is the same as concatenating U and V and multiplying by a single W.) Here inputs play the role of U and attention the role of V (the attention vector itself is roughly tf.concat([query, alignments * memory]) followed by an output projection).
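To see why the concat trick works, here is a tiny numerical check with toy shapes and hypothetical variable names: W1*U + W2*V equals concat(U, V) multiplied by the vertically stacked matrix [W1; W2].

import tensorflow as tf

U = tf.random_normal([2, 3])   # e.g. the inputs
V = tf.random_normal([2, 5])   # e.g. the attention vector
W1 = tf.random_normal([3, 4])
W2 = tf.random_normal([5, 4])

separate = tf.matmul(U, W1) + tf.matmul(V, W2)
merged = tf.matmul(tf.concat([U, V], axis=1), tf.concat([W1, W2], axis=0))

with tf.Session() as sess:
  a, b = sess.run([separate, merged])
  print(abs(a - b).max())      # ~0, up to float rounding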

 

outputs, state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
    my_decoder, scope='seq_decode')

Finally, dynamic_decode drives the whole decoding flow.
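Putting the four pieces together, here is an end-to-end sketch with made-up shapes and names; AttentionWrapper, OutputProjectionWrapper, attention_layer_size, and the toy sizes are my additions for illustration, not part of the snippets above.

import tensorflow as tf

batch_size, enc_time, dec_time = 4, 7, 6
num_units, emb_dim, vocab_size = 128, 64, 1000

encoder_outputs = tf.random_normal([batch_size, enc_time, num_units])
encoder_sequence_length = tf.fill([batch_size], enc_time)
decoder_inputs = tf.random_normal([batch_size, dec_time, emb_dim])   # stand-in for embedded inputs
decoder_sequence_length = tf.fill([batch_size], dec_time)
emb = tf.get_variable("emb", [vocab_size, emb_dim])

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=encoder_sequence_length)

cell = tf.contrib.rnn.BasicLSTMCell(num_units)
cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism, attention_layer_size=num_units)
# project to vocab-sized logits so the helper can sample ids from the output
cell = tf.contrib.rnn.OutputProjectionWrapper(cell, vocab_size)

helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
    decoder_inputs, tf.to_int32(decoder_sequence_length), emb,
    tf.constant(0.1))

my_decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell, helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32))

outputs, state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
    my_decoder, scope='seq_decode')
# outputs.rnn_output: [batch_size, <= dec_time, vocab_size]
# outputs.sample_id:  [batch_size, <= dec_time]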


Some background first:

First, look at:

class BasicRNNCell(RNNCell):

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    if self._linear is None:
      self._linear = _Linear([inputs, state], self._num_units, True)

    output = self._activation(self._linear([inputs, state]))
    return output, output

This is the core: the implementation of W * input + U * state + B. TF implements it with _Linear, which simply concatenates input and state and multiplies by a single weight matrix W (plus a bias). Since a plain RNN only has a hidden state, state here is just the hidden.
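A hand-rolled equivalent of what _Linear does here (my own sketch, not the TF source): concatenate the arguments along the feature axis, then do one matmul plus bias; applying the activation to that result is the whole BasicRNNCell step.

import tensorflow as tf

def my_linear(args, output_size):
  """Rough stand-in for _Linear: concat(args) @ W + b."""
  total_in = sum(int(a.shape[-1]) for a in args)
  w = tf.get_variable("kernel", [total_in, output_size])
  b = tf.get_variable("bias", [output_size], initializer=tf.zeros_initializer())
  return tf.matmul(tf.concat(args, axis=1), w) + b

x = tf.random_normal([4, 10])   # inputs
h = tf.random_normal([4, 20])   # state (= hidden for a plain RNN)
with tf.variable_scope("basic_rnn_sketch"):
  new_h = tf.tanh(my_linear([x, h], 20))   # output == new_state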

 

Next, look at:

class BasicLSTMCell(RNNCell):

  def call(self, inputs, state):
    sigmoid = math_ops.sigmoid
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    if self._linear is None:
      self._linear = _Linear([inputs, h], 4 * self._num_units, True)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

Now it is very clear: since the LSTM state has two parts, the hidden state h and the cell state c, the first step is to split them. Then inputs and h go through the linear layer; because we need four gate outputs, the output dimension must be 4 * _num_units. The rest follows the LSTM equations, and finally the new hidden and the new state are returned, which is also quite intuitive.

Next, let's see how things work once attention is added:

Here the attention comes from the encoder hidden states. According to the formula, the attention vector is concatenated with the decoder hidden state to form one big hidden state, which then goes into the network together with the inputs.

However, TF implements it the other way around: it first concatenates the attention with the inputs, and then feeds that concatenation, together with the decoder hidden state, into the cell. Why is this allowed? Because inside the cell everything is concatenated and passed through a single linear layer anyway (see the BasicLSTMCell implementation above). So the only thing that matters is that (inputs, attention, decoder hidden) all end up concatenated before the matmul; the order does not matter. At this point you finally understand what AttentionWrapper actually does (a simplified sketch of its step follows the _compute_attention listing below).

How is the attention itself computed? There is a _compute_attention function, which I find very direct; attention_mechanism is the scoring scheme you chose for mapping queries to attention weights over the memory:

def _compute_attention(attention_mechanism, cell_output, previous_alignments,
                       attention_layer):
  """Computes the attention and alignments for a given attention_mechanism."""
  alignments = attention_mechanism(
      cell_output, previous_alignments=previous_alignments)

  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Context is the inner product of alignments and values along the
  # memory time dimension.
  # alignments shape is
  #   [batch_size, 1, memory_time]
  # attention_mechanism.values shape is
  #   [batch_size, memory_time, memory_size]
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments
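Finally, here is a simplified sketch of how AttentionWrapper's step strings these pieces together. This is my own condensed version, not the real TF source; it assumes the _compute_attention above is in scope along with `from tensorflow.python.ops import array_ops, math_ops`. The order is: concatenate the previous attention with the inputs, run the wrapped cell, then call _compute_attention on the fresh cell output.

import tensorflow as tf

def attention_wrapper_step(cell, attention_mechanism, inputs, state,
                           attention_layer=None):
  # state is assumed to carry (cell_state, attention, alignments),
  # loosely mirroring AttentionWrapperState.
  cell_state, prev_attention, prev_alignments = state

  # 1. concat the attention from the previous step with the inputs
  cell_inputs = tf.concat([inputs, prev_attention], axis=-1)
  # 2. run the wrapped cell on (cell_inputs, decoder hidden state)
  cell_output, next_cell_state = cell(cell_inputs, cell_state)
  # 3. compute the new attention from the fresh cell output
  attention, alignments = _compute_attention(
      attention_mechanism, cell_output, prev_alignments, attention_layer)

  next_state = (next_cell_state, attention, alignments)
  return attention, next_state   # the real wrapper can also return cell_output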

