tensorflow基於 Grammar as a Foreign Language實現,這篇論文給出的公式也比較清楚。
這里關注seq2seq.attention_decode函數,
- 主要輸入
decoder_inputs,
initial_state,
attention_states,
這里可以主要參考 models/textsum的應用,textsum采用的多層雙向lstm,
假設只有一層,texsum將正向 最后輸出的state作為 attention_decode的輸入initial_state
(不過很多論文認為用逆向最后的state可能效果更好)
對應decocer_inputs就是標注的摘要的字符序列id對應查找到的embedding序列
而attention_states是正向負向輸出concatenate的所有outputs(hidden注意output和hidden是等同概念)
- 關於linear
首先注意到在attention_decode函數用到了一個linear這個定義在rnn_cell._linear函數
他的輸入是 一個list 可能的輸入是比如
[ [batch_size, lenght1], [batch_size_length2]]
對應一個list 2個數組
它的作用是內部定義一個數組 對應這個例子 [length1 + length2, output_size]
也就是起到將[batch_size, length1][batch_size, length2]的序列輸入映射到 [batch_size, output_size]的輸出
這個在attention機制最后會遇到
先看attention的公式
將encoder的hidden states表示為
(h 1 , . . . , h T A)
將decoder的hidden states表示為
(d 1 , . . . , d T B) := (h T A +1 , . . . , h T A +T B).
這里最后計算得到的
就是attention的結果 對應一個樣本 就是長度為 atten_size的向量(就是所有attention輸入向量按照第三個公式的線性疊加之后的結果)那么對應batch_size的輸入 就是[batch_size, atten_size]的一個結果。
論文中提到后面會用到這個attention,
也就是說會concat attention的結果和原始hidden state的結果,那么如何使用呢,tf的做法
x = linear([inp] + attns, input_size, True)
# Run the RNN.
cell_output, state = cell(x, state)
就是說 inp是 [batch_size, input_size], attns [batch_size, attn_size] linear的輸入對應 input_size
即在linear內部經過input和attns concate之后輸出[batch_size, input_size]使得能夠x作為輸入繼續進行rnn過程
-
attention公式
繼續看attention公式 ,不要考慮batch_size就是按照一個樣本來考慮
第一個公式 對應3個舉止 W1,W2都是[attn_size, atten_size]的正方形矩陣,h,d對應 [attent_size, 1]的向量
v對應[atten_size, 1]的矩陣,
那么就是線性疊加之后做非線性變化tanh([attn_size, 1])->[attn_size, 1]最后和v做dot得到一個數值 表示u(i,t)
即對應第i個attention向量在decode的t時刻時候應該的權重大小,
第二個公式表示使用softmax做歸一化得到權重向量概率大小。
第三個公式上面已經分析。
-
tensorflow中attention的實現
- 步驟1
這里第一個問題是我們按照batch操作所以對應處理的不是一個樣本而是一批batch_size個樣本。
那么上面的操作就不能按照tf.matmul來執行了,因為[batch_size, x, y][y, 1]這樣相乘是不行的
tf的做法是使用1by1 convolution來完成,主要利用1by1 + num_channels + num_filters
關於conv2d的使用特別是配合1by1,num_channels, num_filters 這里解釋的非常清楚
http://stackoverflow.com/questions/34619177/what-does-tf-nn-conv2d-do-in-tensorflow
# To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
hidden = array_ops.reshape(
attention_states, [-1, attn_length, 1, attn_size])
hidden_features = []
v = []
attention_vec_size = attn_size # Size of query vectors for attention.
for a in xrange(num_heads):
k = variable_scope.get_variable("AttnW_%d" % a,
[1, 1, attn_size, attention_vec_size])
hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
v.append(
variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
atention_vec_szie == attn_size
attn_size 對應 num_channels (num_channels個位置相乘加和 dot)
attention_vec_size 對應 num_filters
剛好這個conv2d的對應就是batch_size版本的attention的第一個公式里面的 W1 * h_t
Conv2d輸出[batch_size, atten_length, 1, attention_vec_size]
- def attention(query)的分析
attention(query)的輸入是rnn上一步輸出的state
輸出 attns = attention(state)對應 [batch_size, attn_size]的矩陣
對應當前步驟需要用到的attention
def attention(query):
"""Put attention masks on hidden using hidden_features and query."""
ds = [] # Results of attention reads will be stored here.
if nest.is_sequence(query): # If the query is a tuple, flatten it.
query_list = nest.flatten(query)
for q in query_list: # Check that ndims == 2 if specified.
ndims = q.get_shape().ndims
if ndims:
assert ndims == 2
query = array_ops.concat(1, query_list)
for a in xrange(num_heads):
with variable_scope.variable_scope("Attention_%d" % a):
y = linear(query, attention_vec_size, True)
y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
# Attention mask is a softmax of v^T * tanh(...).
s = math_ops.reduce_sum(
v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
a = nn_ops.softmax(s)
# Now calculate the attention-weighted vector d.
d = math_ops.reduce_sum(
array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,
[1, 2])
ds.append(array_ops.reshape(d, [-1, attn_size]))
return ds
首先目前默認都是用state_is_tuple=True選項(這樣效率更高,后面state_is_tupe=False將會depreciated)
前面已經說過tf實現的state對應兩個(cell_state, hidden_state)
所以這里nest_issequence是True 對應最后處理后query 就是 [batch_size, 2 * input_size]
y = linear(query, attention_vec_size, True)
y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
對應W2dt的計算
hidden_features[a] + y 則注意是 W2dt累加到 所有的hi(attn_length個)
a對應[batdh_size, attn_length]
Reshape[batch_size, atten_length, 1, 1]
Hidden [batch_size, atten_length, 1, atten_size]
最終返回 [batch_size, attn_size]