lstm公式推導


http://blog.csdn.net/u010754290/article/details/47167979

導言

在Alex Graves的這篇論文《Supervised Sequence Labelling with Recurrent Neural Networks》中對LSTM進行了綜述性的介紹,並對LSTM的Forward Pass和Backward Pass進行了公式推導。

這篇文章將用更簡潔的圖示和公式一步步對Forward和Backward進行推導,相信讀者看完之后能對LSTM有更深入的理解。

如果讀者對LSTM的由來和原理存在困惑,推薦DarkScope的這篇博客:《RNN以及LSTM的介紹和公式梳理》

一、LSTM的基礎結構

LSTM的結構中每個時刻的隱層包含了多個memory blocks(一般我們采用一個block),每個block包含了多個memory cell,每個memory cell包含一個Cell和三個gate,一個基礎的結構示例如下圖: 
image

一個memory cell只能產出一個標量值,一個block能產出一個向量。

二、LSTM的前向傳播(Forward Pass)

1. 引入

首先我們在上述LSTM的基礎結構之上構造時序結構,這樣讓讀者更清晰地看到Recurrent的結構:

LSTM的整體結構

這里我們有幾個約定:

  1. 每個時刻的隱層包含一個block
  2. 每個block包含一個memory cell

下面前向傳播我們則從Input開始,逐個求解Input Gate、Forget Gate、Cells Gate、Ouput Gate和最終的Output

這里需要申明的一點,推導過程嚴格按照上述圖示LSTM的結構;論文中對相較於該文章的推導過程會有增加一些項,在每一個公式不一致的地方我都會有相應說明。

2. Input Gate(ι) 的計算

Input Gate接受兩個輸入:

  1. 當前時刻的Input作為輸入:xt
  2. 上一時刻同一block內所有Cell作為輸入:st1c

該案例中每層僅有單個Block、單個cemory cell,可以忽略Cc=1,以下Forget Gate和Output Gate做相同處理。

Input Gate

最終Input Gate的輸出為:

 

atι=i=1Iωiιxti+c=1Cωcιst1c

 

 

btι=f(atι)

 

這里Input Gate還可以接受上一個時刻中不同block的輸出bt1h作為輸入,論文中atι會增加一項Hh=1ωhιbt1h

3. Forget Gate(ϕ) 的計算

Forget Gate接受兩個輸入:

  1. 當前時刻的Input作為輸入:xt
  2. 上一時刻同一block內所有Cell作為輸入:st1c

Forget Gate

最終Forget Gate的輸出為:

 

atϕ=i=1Iωiϕxti+c=1Cωcϕst1c

 

 

btϕ=f(atϕ)

 

這里Input Gate還可以接受上一個時刻中不同block的輸出bt1h作為輸入,論文中atϕ會增加一項Hh=1ωhϕbt1h

4. Cell(c) 的計算

Cell的計算稍有些復雜,接受兩個輸入:

  1. Input Gate和Input輸入的乘積
  2. Forget Gate和上一時刻對應Cell輸出的乘積

Cell

最終Cell的輸出為:

 

atc=i=1Iωicxti

 

 

stc=btϕst1c+btιg(atc)

 

這里Input Gate還可以接受上一個時刻中不同block的輸出bt1h作為輸入,論文中atc會增加一項Hh=1ωhcbt1h

5. Output Gate(ω) 的計算

Output Gate接受兩個輸入:

  1. 當前時刻的Input作為輸入:xt
  2. 當前時刻同一block內所有Cell作為輸入:stc

這里Output Gate接受“當前時刻Cell的輸出”而不是“上一時刻Cell的輸出”,是由於此時Cell的結果已經產出,我們控制Output Gate的輸出直接采用Cell當前的結果就行了,無須使用上一時刻。

Output Gate

最終Output Gate的輸出為:

 

atω=i=1Iωiωxti+c=1Cωcωstc

 

 

btω=f(atω)

 

這里Cell還可以接受上一個時刻中其他gate鏈接過來的邊,論文中atϕ會增加一項Hh=1ωhϕbt1h,這里H是泛指t-1時刻的Cell或三個Gate。

6. Cell Output(c) 的計算

Cell Output的計算即將Output Gate和Cell做乘積即可。

Cell Output

最終Cell Output為:

 

btc=btωh(stc)

 

7. 小結

至此,整個Block從Input到Output整個Forward Pass已經結束,其中涉及三個Gate和中間Cell的計算,需要注意的是三個Gate使用的激活函數是f,而Input的激活函數是g、Cell輸出的激活函數是h

這里讀者需要注意,在整個計算過程中,當前時刻的三個Gate均可以從上一時刻的任意Gate中接受輸入,在公式中存在體現,但是在圖示中並未畫出相應的邊。我們可以認為只有上一時刻的Cell才和當前時刻的Cell或三個Gate相連。 
前向小結

三、LSTM的反向傳播(Backward Pass)

1. 引入

此處在論文中使用“Backward Pass”一詞,但其實即Back Propagation過程,利用鏈式求導求解整個LSTM中每個權重的梯度。

2. 損失函數的選擇

為了通用起見,在此我們僅展示多分類問題的損失函數的選擇,對於網絡的最終輸出我們利用softmax方程計算結果屬於某一類的概率(此時結果屬於k個類別的概率和為1)。

 

p(Ck|x)=yk=eakKk=1eak

 

注意,ykak的偏導為ykak=ykδkkykykδkkk==k時為1,其他為0)

其中,對於網絡輸出a1,a2,...對應我們可以得到p(C1|x),p(C2|x),...,即給定輸入x輸出類別為C1,C2,...的概率。

這樣損失函數(Loss Function)就很好定義了:對於k1,2,...,K,網絡輸出的類別為k概率為yk,而真實值zk

 

(x,z)=lnp(z|x)=k=1Kzklnyk

 

3. 權重的更新

對於神經網絡中的每一個權重,我們都需要找到對應的梯度,從而通過不斷地用訓練樣本進行隨機梯度下降找到全局最優解,那么首先我們需要知道哪些權重需要更新。

一般層次分明的神經網絡有input層、hidden層和output層,層與層之間的權重比較直觀;但在LSTM中通過公式才能找到對應的權重,和圖示中的邊並不是一一對應,下面我將LSTM的單個Block中需要更新的權重在圖示上標示了出來:

權重

為了方便起見,這里需要申明的是:我們僅考慮上一時刻的Cell僅和當前時刻的Cell和三個Gate相連。

2. Cell Output的梯度

首先我們計算每一個輸出類別的梯度: 

δtk========(x,z)atk(Kk=1zklnyk)atkk=1Kzklnykatkk=1Kzkykykatkk=1Kzkyk(ykδkkykyk)k=1Kzkykykδkk+k=1Kzkykykykzk+ykk=1Kzkykzk

 

也即每一個輸出類別的梯度僅和其預測值和真實值相關,這樣對於Cell Output的梯度則可以通過鏈式求導法則推導出來:

 

ϵtc=(x,z)btc=k=1K(x,z)atkatkbtc=k=1Kδtkωck

 

由於Output還可以連接下一個時刻的一個Cell、三個Gate,那么下一個時刻的一個Cell、三個Gate的梯度則可以傳遞回當前時刻Output,所以在論文中存在額外項Gg=1ωcgδt+1g,為簡便起見,公式和圖示中未包含。

Cell Output

3. Output Gate的梯度

根據鏈式求導法則,Output Gate的梯度可以由以下公式推導出來:

 

δtω=(x,z)atω=(x,z)btcbtcbtωbtωatω=ϵtch(stc)f(atw)

 

另外,由於單個Block內可以存在多個memory cell、一個Forget Gate、一個Input Gate和一個Output Gate,論文中將Output Gate的梯度寫成了f(atw)Cc=1ϵtch(stc),但推導過程一致。推導過程見下圖,說明梯度匯總到單個Gate中:

Output Gate

4. Cell的梯度

細心的讀者在這里會發現,Cell的計算結構和普遍的神經網絡不太一樣,讓我們首先來回顧一下Cell部分的Forward計算過程:

 

atc=i=1Iωicxti

 

 

stc=btϕst1c+btιg(atc)

 

輸入數據貢獻給atc,而Cell同時能夠接受Input Gate和Forget Gate的輸入。

這樣梯度就直接從Cell向下傳遞:

 

δtc=(x,z)atc=(x,z)stcstcatc=(x,z)stcbtιg(atc)

 

在這里,我們定義States,由於Cell的梯度可以由以下幾個計算單元傳遞回來:

  1. 當前時刻的Cell Output
  2. 下一個時刻的Cell
  3. 下一個時刻的Input Gate
  4. 下一個時刻的Output Gate

那么States可以這樣求解,上面1~4個能夠回傳梯度的計算單元和下面公式中一一對應: 

ϵts====(x,z)stct(x,z)stc+t+1(x,z)st+1cst+1cstc+t+1(x,z)at+1ιat+1ιstc+t+1(x,z)at+1ϕat+1ϕstc((x,z)atwatwstc+(x,z)btcbtcstc)+bt+1ϕϵt+1s+ωcιδt+1ι+ωcϕδt+1ϕδtωωcω+ϵtcbtωh(stc)+bt+1ϕϵt+1s+ωcιδt+1ι+ωcϕδt+1ϕ

 

那么: 

δtc=ϵtsbtιg(atc)

 

Cell

細心的讀者會發現,論文中(x,z)btc並沒有求和,這里作者持保留態度,應該存在求和項。

同時由於Cell可以連接到下一個時刻的Forget Gate、Output Gate和Input Gate,那么下一時刻的這三個Gate則可以將梯度傳播回來,所以在論文中我們會發現ϵts擁有這三項:bt+1ϕϵt+1sωclδt+1ιωcϕδt+1ϕ

5. Forget Gate的梯度

Forget Gate的梯度計算就比較簡單明了:

 

δtϕ=(x,z)atϕ=(x,z)stcstcbtϕbtϕatϕ=ϵtsst1cf(atϕ)

 

Forget Gate

另外,由於單個Block內可以存在多個memory cell、一個Forget Gate、一個Input Gate和一個Output Gate,論文中將Forget Gate的梯度寫成了f(atϕ)Cc=1st1cϵts,但推導過程一致,說明梯度匯總到單個Gate中。

6. Input Gate的梯度

Input Gate的梯度計算如下:

 

δtι=(x,z)atι=(x,z)stcstcbtιbtιatι=ϵtsg(atc)f(atι)

 

Input Gate

另外,由於單個Block內可以存在多個memory cell、一個Forget Gate、一個Input Gate和一個Output Gate,論文中將Input Gate的梯度寫成了f(atι)Cc=1g(atc)ϵts,但推導過程一致,說明梯度匯總到單個Gate中。

7. 小結

至此,所有的梯度求解已經結束,同樣我們將這個Backward Pass的所有公式列出來:

小結

剩下的事情即利用梯度去更新每個權重: 

Δωn=mΔωn1αωn

 

其中mΔωn1為上一次權重的更新值,且m[0,1];而ωn即上面我們求到的每一個梯度。

例如每次更新ωiϕΔ量即: 

Δωniϕ=mΔωn1iϕαxiδtϕ

 

其中δtϕ即Forget Gate的梯度。

三、總結

以上就是LSTM中的前向和反向傳播的公式推導,在這里作者僅以最簡單的單個Cell的場景進行示例。

在實際工程實踐中,常常會涉及到同一時刻多個Cell且互相之間的Gate存在連接,同時上一個時刻或下一個時刻的Cell和三個Gate之間同樣存在復雜的連接關系。

但如果讀者能夠明晰上述的推導過程,那么無論多復雜都能夠迎刃而解了。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM