http://blog.csdn.net/u010754290/article/details/47167979
導言
在Alex Graves的這篇論文《Supervised Sequence Labelling with Recurrent Neural Networks》中對LSTM進行了綜述性的介紹,並對LSTM的Forward Pass和Backward Pass進行了公式推導。
這篇文章將用更簡潔的圖示和公式一步步對Forward和Backward進行推導,相信讀者看完之后能對LSTM有更深入的理解。
如果讀者對LSTM的由來和原理存在困惑,推薦DarkScope的這篇博客:《RNN以及LSTM的介紹和公式梳理》
一、LSTM的基礎結構
LSTM的結構中每個時刻的隱層包含了多個memory blocks(一般我們采用一個block),每個block包含了多個memory cell,每個memory cell包含一個Cell和三個gate,一個基礎的結構示例如下圖:
一個memory cell只能產出一個標量值,一個block能產出一個向量。
二、LSTM的前向傳播(Forward Pass)
1. 引入
首先我們在上述LSTM的基礎結構之上構造時序結構,這樣讓讀者更清晰地看到Recurrent的結構:
這里我們有幾個約定:
- 每個時刻的隱層包含一個block
- 每個block包含一個memory cell
下面前向傳播我們則從Input開始,逐個求解Input Gate、Forget Gate、Cells Gate、Ouput Gate和最終的Output
這里需要申明的一點,推導過程嚴格按照上述圖示LSTM的結構;論文中對相較於該文章的推導過程會有增加一些項,在每一個公式不一致的地方我都會有相應說明。
2. Input Gate(ι) 的計算
Input Gate接受兩個輸入:
- 當前時刻的Input作為輸入:xt
- 上一時刻同一block內所有Cell作為輸入:st−1c
該案例中每層僅有單個Block、單個cemory cell,可以忽略∑Cc=1,以下Forget Gate和Output Gate做相同處理。
最終Input Gate的輸出為:
這里Input Gate還可以接受上一個時刻中不同block的輸出bt−1h作為輸入,論文中atι會增加一項∑Hh=1ωhιbt−1h。
3. Forget Gate(ϕ) 的計算
Forget Gate接受兩個輸入:
- 當前時刻的Input作為輸入:xt
- 上一時刻同一block內所有Cell作為輸入:st−1c
最終Forget Gate的輸出為:
這里Input Gate還可以接受上一個時刻中不同block的輸出bt−1h作為輸入,論文中atϕ會增加一項∑Hh=1ωhϕbt−1h。
4. Cell(c) 的計算
Cell的計算稍有些復雜,接受兩個輸入:
- Input Gate和Input輸入的乘積
- Forget Gate和上一時刻對應Cell輸出的乘積
最終Cell的輸出為:
這里Input Gate還可以接受上一個時刻中不同block的輸出bt−1h作為輸入,論文中atc會增加一項∑Hh=1ωhcbt−1h。
5. Output Gate(ω) 的計算
Output Gate接受兩個輸入:
- 當前時刻的Input作為輸入:xt
- 當前時刻同一block內所有Cell作為輸入:stc
這里Output Gate接受“當前時刻Cell的輸出”而不是“上一時刻Cell的輸出”,是由於此時Cell的結果已經產出,我們控制Output Gate的輸出直接采用Cell當前的結果就行了,無須使用上一時刻。
最終Output Gate的輸出為:
這里Cell還可以接受上一個時刻中其他gate鏈接過來的邊,論文中atϕ會增加一項∑Hh=1ωhϕbt−1h,這里H是泛指t-1時刻的Cell或三個Gate。
6. Cell Output(c) 的計算
Cell Output的計算即將Output Gate和Cell做乘積即可。
最終Cell Output為:
7. 小結
至此,整個Block從Input到Output整個Forward Pass已經結束,其中涉及三個Gate和中間Cell的計算,需要注意的是三個Gate使用的激活函數是f,而Input的激活函數是g、Cell輸出的激活函數是h。
這里讀者需要注意,在整個計算過程中,當前時刻的三個Gate均可以從上一時刻的任意Gate中接受輸入,在公式中存在體現,但是在圖示中並未畫出相應的邊。我們可以認為只有上一時刻的Cell才和當前時刻的Cell或三個Gate相連。
三、LSTM的反向傳播(Backward Pass)
1. 引入
此處在論文中使用“Backward Pass”一詞,但其實即Back Propagation過程,利用鏈式求導求解整個LSTM中每個權重的梯度。
2. 損失函數的選擇
為了通用起見,在此我們僅展示多分類問題的損失函數的選擇,對於網絡的最終輸出我們利用softmax方程計算結果屬於某一類的概率(此時結果屬於k個類別的概率和為1)。
注意,yk對ak的偏導為∂yk′∂ak=ykδkk′−ykyk′(δkk′當k==k′時為1,其他為0)
其中,對於網絡輸出a1,a2,...對應我們可以得到p(C1|x),p(C2|x),...,即給定輸入x輸出類別為C1,C2,...的概率。
這樣損失函數(Loss Function)就很好定義了:對於k∈1,2,...,K,網絡輸出的類別為k概率為yk,而真實值zk:
3. 權重的更新
對於神經網絡中的每一個權重,我們都需要找到對應的梯度,從而通過不斷地用訓練樣本進行隨機梯度下降找到全局最優解,那么首先我們需要知道哪些權重需要更新。
一般層次分明的神經網絡有input層、hidden層和output層,層與層之間的權重比較直觀;但在LSTM中通過公式才能找到對應的權重,和圖示中的邊並不是一一對應,下面我將LSTM的單個Block中需要更新的權重在圖示上標示了出來:
為了方便起見,這里需要申明的是:我們僅考慮上一時刻的Cell僅和當前時刻的Cell和三個Gate相連。
2. Cell Output的梯度
首先我們計算每一個輸出類別的梯度: