導讀
目前采用編碼器-解碼器 (Encode-Decode) 結構的模型非常熱門,是因為它在許多領域較其他的傳統模型方法都取得了更好的結果。這種結構的模型通常將輸入序列編碼成一個固定長度的向量表示,對於長度較短的輸入序列而言,該模型能夠學習出對應合理的向量表示。然而,這種模型存在的問題在於:當輸入序列非常長時,模型難以學到合理的向量表示。
在這篇博文中,我們將探索加入LSTM/RNN模型中的attention機制是如何克服傳統編碼器-解碼器結構存在的問題的。
通過閱讀這篇博文,你將會學習到:
- 傳統編碼器-解碼器結構存在的問題及如何將輸入序列編碼成固定的向量表示;
- Attention機制是如何克服上述問題的,以及在模型輸出時是如何考慮輸出與輸入序列的每一項關系的;
- 基於attention機制的LSTM/RNN模型的5個應用領域:機器翻譯、圖片描述、語義蘊涵、語音識別和文本摘要。
讓我們開始學習吧。
一、長輸入序列帶來的問題
使用傳統編碼器-解碼器的RNN模型先用一些LSTM單元來對輸入序列進行學習,編碼為固定長度的向量表示;然后再用一些LSTM單元來讀取這種向量表示並解碼為輸出序列。
采用這種結構的模型在許多比較難的序列預測問題(如文本翻譯)上都取得了最好的結果,因此迅速成為了目前的主流方法。
例如:
- Sequence to Sequence Learning with Neural Networks, 2014
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, 2014
這種結構在很多其他的領域上也取得了不錯的結果。然而,它存在一個問題在於:輸入序列不論長短都會被編碼成一個固定長度的向量表示,而解碼則受限於該固定長度的向量表示。
這個問題限制了模型的性能,尤其是當輸入序列比較長時,模型的性能會變得很差(在文本翻譯任務上表現為待翻譯的原始文本長度過長時翻譯質量較差)。
“一個潛在的問題是,采用編碼器-解碼器結構的神經網絡模型需要將輸入序列中的必要信息表示為一個固定長度的向量,而當輸入序列很長時則難以保留全部的必要信息(因為太多),尤其是當輸入序列的長度比訓練數據集中的更長時。”
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
二、使用Attention機制
Attention機制的基本思想是,打破了傳統編碼器-解碼器結構在編解碼時都依賴於內部一個固定長度向量的限制。
Attention機制的實現是通過保留LSTM編碼器對輸入序列的中間輸出結果,然后訓練一個模型來對這些輸入進行選擇性的學習並且在模型輸出時將輸出序列與之進行關聯。
換一個角度而言,輸出序列中的每一項的生成概率取決於在輸入序列中選擇了哪些項。
“在文本翻譯任務上,使用attention機制的模型每生成一個詞時都會在輸入序列中找出一個與之最相關的詞集合。之后模型根據當前的上下文向量 (context vectors) 和所有之前生成出的詞來預測下一個目標詞。
… 它將輸入序列轉化為一堆向量的序列並自適應地從中選擇一個子集來解碼出目標翻譯文本。這感覺上像是用於文本翻譯的神經網絡模型需要“壓縮”輸入文本中的所有信息為一個固定長度的向量,不論輸入文本的長短。”
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
雖然模型使用attention機制之后會增加計算量,但是性能水平能夠得到提升。另外,使用attention機制便於理解在模型輸出過程中輸入序列中的信息是如何影響最后生成序列的。這有助於我們更好地理解模型的內部運作機制以及對一些特定的輸入-輸出進行debug。
“論文提出的方法能夠直觀地觀察到生成序列中的每個詞與輸入序列中一些詞的對齊關系,這可以通過對標注 (annotations) 權重參數可視化來實現…每個圖中矩陣的每一行表示與標注相關聯的權重。由此我們可以看出在生成目標詞時,源句子中的位置信息會被認為更重要。”
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
三、大型圖片帶來的問題
被廣泛應用於計算機視覺領域的卷積神經網絡模型同樣存在類似的問題: 對於特別大的圖片輸入,模型學習起來比較困難。
由此,一種啟發式的方法是將在模型做預測之前先對大型圖片進行某種近似的表示。
“人類的感知有一個重要的特性是不會立即處理外界的全部輸入,相反的,人類會將注意力專注於所選擇的部分來得到所需要的信息,然后結合不同時間段的局部信息來建立一個內部的場景表示,從而引導眼球的移動及做出決策。”
這種啟發式方法某種程度上也可以認為是考慮了attention,但在這篇博文中,這種方法並不認為是基於attention機制的。
基於attention機制的相關論文如下:
- Recurrent Models of Visual Attention, 2014
- DRAW: A Recurrent Neural Network For Image Generation, 2014
- Multiple Object Recognition with Visual Attention, 2014
四、基於Attention模型的應用實例
這部分將列舉幾個具體的應用實例,介紹attention機制是如何用在LSTM/RNN模型來進行序列預測的。
1. Attention在文本翻譯任務上的應用
文本翻譯這個實例在前面已經提過了。
給定一個法語的句子作為輸入序列,需要輸出翻譯為英語的句子。Attention機制被用在輸出輸出序列中的每個詞時會專注考慮輸入序列中的一些被認為比較重要的詞。
我們對原始的編碼器-解碼器模型進行了改進,使其有一個模型來對輸入內容進行搜索,也就是說在生成目標詞時會有一個編碼器來做這個事情。這打破了之前的模型是基於將整個輸入序列強行編碼為一個固定長度向量的限制,同時也讓模型在生成下一個目標詞時重點考慮輸入中相關的信息。
— Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015
Attention在文本翻譯任務(輸入為法語文本序列,輸出為英語文本序列)上的可視化(圖片來源於Dzmitry Bahdanau, et al., Neural machine translation by jointly learning to align and translate, 2015)
2. Attention在圖片描述上的應用
與之前啟發式方法不同的是,基於序列生成的attention機制可以應用在計算機視覺相關的任務上,幫助卷積神經網絡重點關注圖片的一些局部信息來生成相應的序列,典型的任務就是對一張圖片進行文本描述。
給定一張圖片作為輸入,輸出對應的英文文本描述。Attention機制被用在輸出輸出序列的每個詞時會專注考慮圖片中不同的局部信息。
我們提出了一種基於attention的方法,該方法在3個標准數據集上都取得了最佳的結果……同時展現了attention機制能夠更好地幫助我們理解模型地生成過程,模型學習到的對齊關系與人類的直觀認知非常的接近(如下圖)。
— Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, 2016
Attention在圖片描述任務(輸入為圖片,輸出為描述的文本)上的可視化(圖片來源於Attend and Tell: Neural Image Caption Generation with Visual Attention, 2016)
3. Attention在語義蘊涵 (Entailment) 中的應用
給定一個用英文描述的前提和假設作為輸入,輸出假設與前提是否矛盾、是否相關或者是否成立。
舉個例子:
前提:在一個婚禮派對上拍照
假設:有人結婚了
該例子中的假設是成立的。
Attention機制被用於關聯假設和前提描述文本之間詞與詞的關系。
我們提出了一種基於LSTM的神經網絡模型,和把每個輸入文本都獨立編碼為一個語義向量的模型不同的是,該模型同時讀取前提和假設兩個描述的文本序列並判斷假設是否成立。我們在模型中加入了attention機制來找出假設和前提文本中詞/短語之間的對齊關系。……加入attention機制能夠使模型在實驗結果上有2.6個點的提升,這是目前數據集上取得的最好結果…
Attention在語義蘊涵任務(輸入是前提文本,輸出是假設文本)上的可視化(圖片來源於Reasoning about Entailment with Neural Attention, 2016)
4. Attention在語音識別上的應用
給定一個英文的語音片段作為輸入,輸出對應的音素序列。
Attention機制被用於對輸出序列的每個音素和輸入語音序列中一些特定幀進行關聯。
…一種基於attention機制的端到端可訓練的語音識別模型,能夠結合文本內容和位置信息來選擇輸入序列中下一個進行編碼的位置。該模型有一個優點是能夠識別長度比訓練數據長得多的語音輸入。
Attention在語音識別任務(輸入是音幀,輸出是音素的位置)上的可視化(圖片來源於Attention-Based Models for Speech Recognition, 2015)
5. Attention在文本摘要上的應用
給定一篇英文文章作為輸入序列,輸出一個對應的摘要序列。
Attention機制被用於關聯輸出摘要中的每個詞和輸入中的一些特定詞。
… 在最近神經網絡翻譯模型的發展基礎之上,提出了一個用於生成摘要任務的基於attention的神經網絡模型。通過將這個概率模型與一個生成式方法相結合來生成出准確的摘要。
— A Neural Attention Model for Abstractive Sentence Summarization, 2015
Attention在文本摘要任務(輸入為文章,輸出為文本摘要)上的可視化(圖片來源於A Neural Attention Model for Abstractive Sentence Summarization, 2015)
五、Attention的數學解釋
1. 原來的Encoder–Decoder
在這個模型中,encoder只將最后一個輸出遞給了decoder,這樣一來,decoder就相當於對輸入只知道梗概意思,而無法得到更多輸入的細節,比如輸入的位置信息。所以想想就知道了,如果輸入的句子比較短、意思比較簡單,翻譯起來還行,長了復雜了就做不好了嘛。
2. 對齊問題
前面說了,只給我遞來最后一個輸出,不好;但如果把每個step的輸出都傳給我,又有一個問題了,怎么對齊?
什么是對齊?比如說英文翻譯成中文,假設英文有10個詞,對應的中文翻譯只有6個詞,那么就有了哪些英文詞對哪些中文詞的問題了嘛。
傳統的翻譯專門有一塊是搞對齊的,是一個比較獨立的task(傳統的NLP基本上每一塊都是獨立的task啦)。
3. attention機制
我們從輸出端,即decoder部分,倒過來一步一步看公式。
$$ S_t=f(S_{t-1}, y_{t-1}, c_t) \tag{1} $$
$S_t$是指decoder在$t$時刻的狀態輸出,$S_{t-1}$是指decoder在$t-1$時刻的狀態輸出,$y_{t-1}$是$t-1$時刻的label(注意是label,不是我們輸出的$y$),$c_t$看下一個公式,$f$是一個RNN。
$$ {c_{t}} = \sum\limits_{j = 1}^{{T_x}} {{a_{tj}}{h_j}} \tag{2} $$
$h_j$是指第$j$個輸入在encoder里的輸出,$a_{tj}$是一個權重
$$ {a_{tj}} = \frac{{exp \left( {{e_{tj}}} \right)}}{{\sum\nolimits_{k = 1}^{{T_x}} {exp \left( {{e_{tk}}} \right)} }} \tag{3}$$
這個公式跟softmax是何其相似,道理是一樣的,是為了得到條件概率$P(a|e)$,這個$a$的意義是當前這一步decoder對齊第$j$個輸入的程度。
最后一個公式,
$$ e_{tj} = g(S_{t-1}, h_j) = V\cdot \tanh { \left( W\cdot h_j+U\cdot S_{t-1}+b \right) } \tag{4}$$
這個$g$可以用一個小型的神經網絡來逼近,它用來計算$S_{t-1}$, $h_j$這兩者的關系分數,如果分數大則說明關注度較高,注意力分布就會更加集中在這個輸入單詞上,這個函數在文章Neural Machine Translation by Jointly Learning to Align and Translate(2014)中稱之為校准模型(alignment model),文中提到這個函數是RNN前饋網絡中的一系列參數,在訓練過程會訓練這些參數, 基於Attention-Based LSTM模型的文本分類技術的研究(2016)給出了上式的右側部分作為拓展。
好了,把四個公式串起來看,這個attention機制可以總結為一句話:當前一步輸出$S_t$應該對齊哪一步輸入,主要取決於前一步輸出$S_{t-1}$和這一步輸入的encoder結果$h_j$。
看了這個方法的感受是,計算力發達的這個年代,真是什么復雜的東西都有人敢試了啊。這要是放在以前,得跑多久才能收斂啊......
神經網絡搞NLP雖然還有諸多受限的地方,但這種end-to-end 的one task方式,太吸引人,有前途。
進一步的閱讀
如果你想進一步地學習如何在LSTM/RNN模型中加入attention機制,可閱讀以下論文:
- Attention and memory in deep learning and NLP
- Attention Mechanism
- Survey on Attention-based Models Applied in NLP
- What is exactly the attention mechanism introduced to RNN? (來自Quora)
- What is Attention Mechanism in Neural Networks?
目前Keras官方還沒有單獨將attention模型的代碼開源,下面有一些第三方的實現:
- Deep Language Modeling for Question Answering using Keras
- Attention Model Available!
- Keras Attention Mechanism
- Attention and Augmented Recurrent Neural Networks
- How to add Attention on top of a Recurrent Layer (Text Classification)
- Attention Mechanism Implementation Issue
- Implementing simple neural attention model (for padded inputs)
- Attention layer requires another PR
- seq2seq library
總結
通過這篇博文,你應該學習到了attention機制是如何應用在LSTM/RNN模型中來解決序列預測存在的問題。
具體而言,采用傳統編碼器-解碼器結構的LSTM/RNN模型存在一個問題:不論輸入長短都將其編碼成一個固定長度的向量表示,這使模型對於長輸入序列的學習效果很差(解碼效果很差)。而attention機制則克服了上述問題,原理是在模型輸出時會選擇性地專注考慮輸入中的對應相關的信息。使用attention機制的方法被廣泛應用在各種序列預測任務上,包括文本翻譯、語音識別等。
感謝原作者Jason Brownlee。原文鏈接見:Attention in Long Short-Term Memory Recurrent Neural Networks
轉載:http://www.jeyzhang.com/understand-attention-in-rnn.html