原文鏈接:https://zhuanlan.zhihu.com/p/353680367
此篇文章內容源自 Attention Is All You Need,若侵犯版權,請告知本人刪帖。
原論文下載地址:
https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
摘要
主要的序列轉導模型是基於復雜的遞歸或卷積神經網絡的,這些網絡包含一個編碼器和一個解碼器。上述模型中最佳的模型使用了注意力機制連接了編碼器和解碼器。作者提出一個新的簡單的網絡架構,即 Tranformer,僅基於注意力機制,完全摒棄了遞歸和卷積。在兩個機器翻譯任務上的實驗表明,這些模型在質量上具有優勢,同時具有更高的可並行性,並且所需的訓練時間明顯更少。作者提出的模型在 WMT 2014 英語-德語 翻譯任務上達到了 28.4 BLEU,比現有的最佳結果(包括集成)提高了2 BLEU 以上。在WMT 2014英語-法語翻譯任務中,我們的模型在 8 個 gpu 上訓練3.5天后,建立了一個新的單模型最佳的 41.0 BLEU 分數,這只相當於文獻中最好模型所需訓練成本的一小部分。
1. 模型架構
最具競爭力的神經序列轉導模型具有編碼器-解碼器結構。編碼器將符號表示的輸入序列 映射為一個連續表示的序列
。給定
,解碼器將生成一個符號表示的輸出序列
,每次生成輸出序列中的一個元素。在每一步模型都是自回歸的,在生成下一個符號時將之前生成的所有符號作為附加輸入。
Transformer 遵從以下所述的整體架構,該架構針對編碼器和解碼器使用了堆疊的自注意力機制、逐點計算和全連接層,編碼器和解碼器分別如圖 1 的左右部分所示。

1.1 編碼器和解碼器堆疊
編碼器
編碼器由 N 個完全相同的層堆疊而成,此處 N=6。每一層都是一個雙子層結構。第一個子層是一個多頭部自注意力機制,第二個子層是一個簡單的按位全連接的前饋網絡。作者針對每個雙子層結構中的各個子層使用了殘差連接[1],並在殘差連接后使用了層歸一化[2]。層歸一化是對每個子層的輸出執行 ,其中
是子層自身實現的功能。為了易於殘差連接,模型中的所有子層以及嵌入層均生成
維度的輸出,此處
。
解碼器
解碼器也由 N 個完全相同的層堆疊而成,此處 N=6。基於編碼器中的雙子層結構,解碼器插入了第三個子層,該子層針對編碼器模塊的輸出執行多頭部注意力。與編碼器類似,作者也在解碼器的每個子層結構的各個子層中使用了殘差連接,並執行了層歸一化。作者也更新了解碼器模塊中的自注意力子層,防止當前位置影響后續位置。這種掩碼機制連同輸出嵌入具有一個位置偏移的實際情況,保證了針對位置 的預測可以只依賴於位置小於
的已知輸出。
2. 注意力
注意力功能可以被描述為一個查詢和鍵值對到一個輸出的映射,其中查詢、鍵、值和輸出全部是向量。輸出是值的加權和,其中分配給每個值的權重是根據查詢和對應鍵的兼容性函數計算的。
2.1 縮放點積注意力
作者稱其提出的特殊注意力為“縮放點積注意力(Scaled Dot-Product Attention)”,如圖 2 所示。輸入由 維度的查詢和鍵,以及
維度的值組成。計算每個查詢和所有鍵的點積,每個查詢對應的點積結果除以
,然后施加 softmax 函數來獲取該值的權重。其他查詢遵循此方法。

實際中,作者針對一個查詢集合同時計算其注意力函數,查詢集合拼接為矩陣 Q。鍵和值也拼接為矩陣 K 和 V。按照公式 (1) 計算輸出矩陣。
(1)
假定有 b 個查詢,那么 Q、K 的維度為,V 的維度為
。
的維度為
,softmaxt 函數的輸出維度也是
,Attention 函數的輸出維度就是
。
兩個最常用的注意力函數是加法注意力[3]和點積(乘法)注意力。除了縮放因子 ,點積注意力與作者的算法是一致的。加法注意力使用一個具有單個隱藏層的前饋網絡計算兼容性函數。雖然兩者在理論復雜性上相似,但是點積注意力在實際中會更快更節省空間,因為它可以使用高度優化的矩陣乘法代碼來實現。
盡管當 取值較小時,兩種注意力機制表現相似,但是當
取值較大且不進行縮放時,加法注意力要優於點積注意力。作者猜測,當
較大值時,點積會大幅度增大,並且將 softmax 函數推入其梯度極小的區域(即其值趨於1或0)。為了緩解該影響,作者通過
縮放點積。
2.2 多頭部注意力
作者發現使用不同的可學習線性投影將查詢、鍵、值線性投影 h 次,在每次投影時分別得到 維度的查詢、鍵、值會有更好的效果,而不是使用
維度的鍵、值、查詢單獨執行一次注意力函數。針對投影后的每一組查詢、鍵、值,作者並行地執行注意力函數,生成一個
維度的輸出值。這些輸出值被拼接在一起,然后再一次被投影,生成最終值,如圖 2 所示。
多頭部注意力可以使模型綜合考慮來自不同位置的不同表示子空間的信息。使用單獨的注意力頭部,平均操作(層歸一化時的操作)會抑制這一特性。
其中,投影對應的參數矩陣為 ,
,
和
。此處的 Concat 操作是各個矩陣按照列來拼接,這樣可以在保證每個矩陣不變的前提下,得到一個列維度為
的矩陣。
假定有 b 個查詢,那么 Q、K 、V的維度為,
和
的維度為
,
的維度為
。根據公式 (1),
的維度為
。Concat 操作為按列拼接,其結果維度為
。MultiHead 的輸出維度為
。
在本文中,作者使用了 的並行注意力層,或可以稱為 8 頭部注意力。其他參數為
。由於每個頭部縮減了維度,整體的計算開銷和維度不變時的單頭部注意力類似。
2.3 模型中注意力機制的應用
本文中 Transformer 通過三種不同方式使用多頭部注意力:
- 在“編碼器-解碼器注意力”層,查詢來自前一個解碼層,存儲的鍵和值來自編碼器的輸出。這使解碼器中每個位置都能參考輸入序列中的所有位置。這模擬了序列至序列模型的典型編碼器-解碼器注意力機制,例如文獻[4][3][5]。
- 編碼器包含了自注意力層。在一個自注意力層中,所有的鍵、值、查詢都來自相同的空間,在本文的模型中,三者均來自編碼器中其前一層的輸出。編碼器中每個位置都可以參考編碼器中前一層所包含的所有位置。
- 類似地,解碼器中的自注意力層可以使解碼器中的每個位置都參考解碼器中的直至(包含)當前位置的全部位置。為了保持解碼器的自回歸特性,我們需要防止信息在解碼器中向左流動。作者通過在縮放點積注意力中屏蔽(設置為
) softmax 的輸入中對應無效連接的所有值來實現這一功能。參見圖 2。
3. 按位前饋網絡
除了注意力子層,作者提出的編碼器和解碼器中的每個層都包含一個全連接前饋網絡,該網絡分別且獨立地應用於每個位置。該網絡包含兩個線性變換,並且兩個變換之間存在一個 ReLU 激活。如公式 (2)。
(2)
盡管不同位置上執行的線性變換操作是一致的,但是它們在各層之間使用不同的參數。另一種實現方法是使用兩個卷積核尺寸為 1 的卷積。輸入和輸出的維度是 ,內部層的維度是
。
4. 嵌入和softmax
與其他序列轉導模型相似,作者使用可學習嵌入來將輸入符號和輸出符號轉換為維度為 的向量。作者也使用常用的可學習線性變換和 softmax 函數將編碼器輸出轉換為下一個被預測符號的概率。在作者的模型中,作者在兩個嵌入層和前一個 softmax 線性變換之間共享相同的權重矩陣,和文獻[6]類似。在嵌入層,作者將這些權重乘以
。
5. 位置編碼
由於作者提出的網絡不包含遞歸和卷積,為了使模型能夠利用序列的順序信息,必須向序列中注入符號的相對位置或絕對位置信息。為此,作者在編碼器和解碼器底部向輸入嵌入添加了“位置編碼”。位置編碼與嵌入具有相同的維度 ,確保了二者可以相加。現有多種可學習和固定位置編碼可供選擇[5]。
在本文中,作者選用了不同頻率的 sine 和 cosine 函數:
其中 pos 表示位置,i 表示維度。也就是說,位置編碼的每個維度對應一個正弦信號。波長構成了 至
的幾何級數。作者選擇該函數是因為猜測其會使模型易於學會如何參考相對位置,因為對於任何固定偏移 k,
可以認為是
的線性函數。
作者也試驗了用可學習位置嵌入[5]替代正弦信號編碼,發現二者產生的結果幾乎相同(表 3 E 行)。作者選擇正弦信號編碼因為其可能會使模型推斷出比訓練過程中遇到的序列長度更長的序列。

表 3 A 行驗證了不同的注意力頭部、注意力鍵維度、注意力值維度的數目,在保持總計算量不變(如“2.2 多頭部注意力”所述)的前提下模型的性能。盡管單頭部注意力比最佳設置低了 0.9 BLEU,但是過多的頭部也會降低模型性能。
表 3 B 行說明了,降低注意力的鍵的大小 會損傷模型質量。這表明確定兼容性並不容易,並且比點乘積更復雜的兼容性功能可能會有所幫助。作者進一步在 C 行和 D 行發現,正如所期待的,模型越大性能越好,dropout 能有效地避免過擬合。在 E 行,作者將正弦位置編碼替換為可學習位置嵌入[5],其結果與基線模型幾乎相同。
6. 為何選擇自注意力
這一部分作者從多個角度比較了自注意力層和常用於將變長的符號表示序列 映射為另一個等長序列
的遞歸層、卷積層,其中
,例如典型的序列轉換編碼器或解碼器中的隱藏層。以下三個需求促使作者使用了自注意力。
第一個是每層的總計算復雜度。第二個是可以並行的計算量,與順序操作所需的最小數目一致。
第三個是網絡中遠程依賴關系之間的路徑長度。學習遠程依賴是許多序列轉導模型的一個主要挑戰。影響學習這種依賴的一個重要因素是網絡中前饋信號和后饋信號必須遍歷的路徑長度。輸入序列和輸出序列中任一位置組合之間的路徑越短,越易於學習遠程依賴[7]。因此,作者也比較了由不同類型的層構成的網絡中,任意兩個輸入輸出位置之間的最大路徑長度。
如表 1 所示,一個自注意力層通過常量級的順序執行操作連接了所有位置,而一個遞歸層需要 個順序操作。就計算復雜度來說,當序列長度 n 比表示維度 d 小時,自注意力層比遞歸層快,而這種情況在機器翻譯領域的 SOTA 模型中使用的句子表征中很常見,例如詞片[4]和字節對[8]表示。為了提高涉及到特別長的序列的任務的計算性能,自注意力可以限制為只考慮以與輸出位置相對應的輸入位置為中心,輸入序列中大小為 r 的鄰域。這將增加最大路徑長度至
。作者計划在未來的工作中進一步研究這種方法。

一個單獨的卷積核寬度 的卷積層無法連接所有輸入輸出位置對。為了連接所有輸入輸出位置對,在連續卷積的情況下,需要
個卷積層堆疊在一起,在離散卷積[9]的情況下需要
個卷積層,這就增加了網絡中任一兩個位置間的最長路徑的長度。卷積層的開銷通常是遞歸層的 k 倍。可分離卷積[10]極大地降低了復雜度至
。然而,即使 k = n 時,可分離卷積的復雜度和一個自注意力層與一個逐點前饋層的組合的復雜度一致,即作者在模型中使用的方法。
作為附帶好處,自注意力可以生成更具解釋性的模型。作者從模型中檢查注意力分布,並在附錄中介紹和討論示例。各個注意力頭部不僅清楚地學會執行不同任務,許多注意力頭部還似乎表現出與句子的句法和語義結構相關的行為。
參考
- ^Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
- ^Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
- ^abDzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. CoRR, abs/1409.0473, 2014.
- ^abYonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine translation system: Bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144, 2016.
- ^abcdJonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. arXiv preprint arXiv:1705.03122v2, 2017.
- ^Ofir Press and Lior Wolf. Using the output embedding to improve language models. arXiv preprint arXiv:1608.05859, 2016.
- ^Sepp Hochreiter, Yoshua Bengio, Paolo Frasconi, and Jürgen Schmidhuber. Gradient flow in recurrent nets: the difficulty of learning long-term dependencies, 2001.
- ^Rico Sennrich, Barry Haddow, and Alexandra Birch. Neural machine translation of rare words with subword units. arXiv preprint arXiv:1508.07909, 2015.
- ^Nal Kalchbrenner, Lasse Espeholt, Karen Simonyan, Aaron van den Oord, Alex Graves, and Koray Kavukcuoglu. Neural machine translation in linear time. arXiv preprint arXiv:1610.10099v2, 2017.
- ^Francois Chollet. Xception: Deep learning with depthwise separable convolutions. arXiv preprint arXiv:1610.02357, 2016.