Attention 機制


有一篇paper叫做attention is all your need,好霸氣的名字。。不過它值得,哈哈,但是我現在不打算解析paper

一開始以為是attention不理解,其實是seq-seq的encoder和decoder部分沒搞明白,感覺現在差不多了,可以自己復述了。

Attention

注意力機制就是為了解決當解碼的序列太長時,越到后面效果就越差。因為在未引入注意力機制之前,解碼時僅僅只依靠上一時刻的輸出而忽略的編碼階段每個時刻的輸出(“稱之為記憶”)。注意力機制的思想在於,希望在解碼的時刻能夠參考編碼階段的記憶,對上一時刻輸出的信息做一定的處理(也就是只注意其中某一部分),然后再喂給下一時刻做解碼處理。這樣就達到了解碼當前時刻時,僅僅只接受與當前時刻有關的輸入,類似與先對信息做了一個篩選(注意力選擇)。Encoder 把所有的輸入序列編碼成了一個c向量,然后使用c向量來進行解碼,因此, 向量中必須包含了原始序列中的所有信息,所以它的壓力其實是很大的,而且由於 RNN 容易把前面的信息“忘記”掉,所以基本的 Seq2Seq 模型,對於較短的輸入來說,效果還是可以接受的,但是在輸入序列比較長的時候, 向量存不下那么多信息,就會導致生成效果大大折扣。

解碼的時候參考編碼階段的記憶,對於encoder輸出的信息做一定的篩選,保留重要的一部分,前篩。。
既然一個上下文\(c\)向量存不了,那么就引入多個\(c\)向量,稱之為\(c_1\)\(c_2\)、…、\(c_i\),在解碼的時候,這里的\(c_i\)對應着Decoder的解碼位次,每次解碼就利用對應的\(c_i\)向量來解碼.這里的每個\(c_i\)向量其實包含了當前所輸出與輸入序列各個部分重要性的相關的信息。
還是需要借助大神的一張清晰的圖
這個是一個seq-seq模型的翻譯demo

上圖右邊的輸入部分的實線表示是訓練時的輸入,虛線表示預測時的輸入。注意向量(attention vector)是由解碼部分每個時刻的計算產生的,此處以計算第一個時刻為例。

encoder-decoder

英文句子“I am a student”被輸入到一個兩個的LSTM編碼網絡(藍色部分),經過編碼(encoder)后輸入到另外一個兩層的LSTM解碼網絡(棕色部分)。當網絡在按時刻進行翻譯decoder
(解碼)的時候,第一個時刻輸出的就是圖中的 \(h_t\)。在前面我們說到,我們希望網絡也能同我們人腦的思考過程一樣,在依次翻譯每個時刻時,網絡所“聯想”到的都是與當前時刻最相關(相似)的映射。換句話說,在神經網絡將"I am a student"翻譯成中文的過程中,當解碼到第一個時刻時,我們希望網絡僅僅只是將注意力集中到單詞"I"上,而盡可能忽略其它單詞的影響。可這說起來容易,具體該怎么做,怎么體現呢?
\(\overline{h_{1}}\)是encoder部分的hiden layer

我們知道 \(h_t\) 是第一個解碼時刻的隱含狀態,同時以上帝視角來看,最與 \(h_t\)相關的部分應該是"I"對應的編碼狀態 \(\overline{h_{1}}\) 。因此,只要網絡在解碼第一個時刻時,將注意力主要集中於 \(\overline{h_{1}}\),也就算達成目的了。但我們怎么才能讓解碼部分的網絡也能知道這一事實呢?好在此時的 \(h_t\) 與編碼部分的隱含狀態都處於同一個Embedding space,所以我們可以通過相似度對比來告訴解碼網絡:哪個編碼時刻的隱含狀態與當前解碼時刻的隱含狀態最為相似。這樣,在解碼當前時刻時,網絡就能將“注意力”盡可能多的集中於對應編碼時刻的隱含狀態。

簡單的來說做下encoder與decoder的相似度
相似度得分有兩種計算方式
\(\operatorname{score}\left(h_{t}, \bar{h}_{s}\right)=\left\{\begin{array}{ll}h_{t}^{T} W \bar{h}_{s}, & \text { [Luong's multiplicative style } ] \\ v^{T} \tanh \left(w_{1} h_{t}+w_{2} \bar{h}_{s}\right), & \text { [Bahdanau's additive style] }\end{array}\right.\)

把相似度得分進行標准化也就是使用softmax,就會得到一個0-1之間的值,找個值就是對應的attention的權重
\(\alpha_{t s}=\frac{\exp \left(\operatorname{score}\left(h_{t}, \bar{h} s\right)\right)}{\sum s^{\prime}=1^{S} \exp \left(s \operatorname{core}\left(h_{t}, \bar{h} s^{\prime}\right)\right)}\)
[Attention weights]

當網絡分別得到當前解碼時刻與所有編碼時刻對應的相似度系數之后(圖中的attention weights),再以加權就和的形式將所有的編碼狀態累加起來得到context vector,最終與 \(h_t\) 組合得到當前decoder解碼時刻的輸出。之所以要以加權求和的形式進行是因為,雖然此時的\(h_t\)僅僅只與 \(\overline{h_{1}}\)最為相關,但同樣也受其它編碼狀態的影響(例如到句型復雜的時候)。但是,若是換了應用場景,只進行對應權重乘以對應隱含狀態,不進行累加也是可以的。

context vector

也就是attention-weight與\(\overline{h_{1}}\)做mutmal矩陣相乘
\(c_{t}=\sum_{s} \alpha t s \bar{h}_{s} \quad[\text { Context vector }]\)

得到的att向量是context vector與decoder端的hiden layer :\(h_t\)做的concatenate
注意,attention vector與\(h_t\)是一一對應的關系,不然辦法做concatenate。

\(c_{t}=\sum_{s} \alpha t s \bar{h}_{s} \quad[\text { Context vector }]\)

我覺得到這里attention機制如何工作以及怎樣計算是ok的,attention就是一個找相似度的過程。但是如何得知將attention機制放到哪一層hiden layer呢?

因此,不知道放到哪個hiden layer:\(h_t\),創造一個hiden layer:\(h_t\)然后跟着網絡一起訓練,可以得到一個動態的weight

\(c_{t}\),\(h_t\)結合作為輸出,也可以直接將 \(c_{t}\)作為輸出,感覺前者的預測效果更好

concatenate的用法

import numpy as np
import keras.backend as K
import tensorflow as tf

a = K.variable(np.array([[[1, 2], [2, 3]], [[4, 4], [5, 3]]]))
b = K.variable(np.array([[[7, 4], [8, 4]], [[2, 10], [15, 11]]]))

c1 = K.concatenate([a, b], axis=0)
c2 = K.concatenate([a, b], axis=1)
c3 = K.concatenate([a, b], axis=2)
#試試默認的參數
c4 = K.concatenate([a, b])

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c1))
    print()
    print(sess.run(c2))
    print()
    print(sess.run(c3))
    print()
    print(sess.run(c4))
以上的axis=0表示列維,1表示行維,沿着通道維度連接兩個張量。

輸出的結果會是這樣的:

[[[ 1.  2.]
  [ 2.  3.]]
 [[ 4.  4.]
  [ 5.  3.]]
 [[ 7.  4.]
  [ 8.  4.]]
 [[ 2. 10.]
  [15. 11.]]]

[[[ 1.  2.]
  [ 2.  3.]
  [ 7.  4.]
  [ 8.  4.]]
 [[ 4.  4.]
  [ 5.  3.]
  [ 2. 10.]
  [15. 11.]]]

[[[ 1.  2.  7.  4.]
  [ 2.  3.  8.  4.]]
 [[ 4.  4.  2. 10.]
  [ 5.  3. 15. 11.]]]

[[[ 1.  2.  7.  4.]
  [ 2.  3.  8.  4.]]
 [[ 4.  4.  2. 10.]
  [ 5.  3. 15. 11.]]]
另外,除聯合維度之外其它維度都必須相等。

相似度度量

Q,K,V

解釋一下attention中出現的這三個矩陣


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM