LSTM網絡(Long Short-Term Memory )


本文基於前兩篇 1. 多層感知機及其BP算法(Multi-Layer Perceptron) 與 2. 遞歸神經網絡(Recurrent Neural Networks,RNN)

RNN 有一個致命的缺陷,傳統的 MLP 也有這個缺陷,看這個缺陷之前,先祭出 RNN 的 反向傳導公式與 MLP 的反向傳導公式:

\[RNN : \ \delta_h^t = f'(a_h^t) \left (\sum_k\delta_k^tw_{hk} + \sum_{h'} \delta^{t+1}_{h'}w_{hh'}   \right )\]

\[MLP : \ \delta_h =   f'(a_h) \sum_{h'=1}^{h_{l+1}} w_{hh'}\delta_{h'}\]

注意,殘差在時間維度上反向傳遞時,每經過一個時刻,就會導致信號的大幅度衰減,為啥呢,就是因為這個非線性激活函數 $f$ ,一般這個函數的形狀如下圖:

如上圖所示,激活函數 $f$ 在在紅線以外的斜度變化很小,所以函數 $f$ 的導數 $f'$ 取值很小,而經過以上列出的殘差反向傳遞公式可以得出,每經過一個時刻,衰減 $f'$ 的數量級,所以經過多個時刻會導致時間維度上梯度呈指數級的衰減,即此刻的反饋信號不能影響太遙遠的過去 。在 MLP 中,如果網絡太深,這種梯度衰減會導致網絡的前幾層的殘差趨近於 0 ,這意味着前面的隱藏層中的神經元學習速度要慢於后面的隱藏層。無論 RNN 還是 MLP ,對參數的導數都是這種形式(RNN需要在時間維度上求和):

\[\frac{\partial O}{\partial w_{ij}} = \frac{\partial O}{\partial a_{j}} \frac{\partial a_j}{\partial w_{ij}} = \delta_jb_i\]

殘差衰減的太小導致參數的導數太小 ,從而梯度下降法中前幾層的參數只有微乎其微的變化,對於深層的 MLP 由於梯度衰減導致效果不如淺層的網絡,對於 RNN 就會導致不能處理長期依賴的問題,雖然 RNN 理論上可以處理任意長度的序列,但實習應用中,RNN 很難處理長度超過 10 的序列。這種現象叫做 gradient vanishing/exploding 。下圖形象的描繪了這種現象:

 

對於 $t=1$ 的輸入,隨着時間的推移,對於 $t >1$ 時刻的產生的影響會越來越小,由圖中的顏色的深淺代表信號的大小。這種衰減會導致 RNN 無法處理長期依賴,舉個例子,當有一句話“I grew up in France … I speak fluent French.”  在預測該人會將一口流利的            語時,會依賴之前他的長大的環境,而序列中兩個詞語的間隔太大,這便是所說的長期依賴問題。  

對於長期以來問題,反向傳播時,梯度也會呈指數倍數的衰減,這種衰減現象導致 RNN 無法處理長期依賴,為了克服 RNN 的這種缺陷,學者們研究了眾多方法,其中 Long Short-Term Memory 表現最為出色。使用 LSTM 模塊后,當誤差從輸出層反向傳播回來時,可以使用模塊的記憶元記下來。所以 LSTM 可以記住比較長時間內的信息。

初始的 LSTM (Hochreiter and Schmidhuber ,1997)網絡結構類似於 RNN ,只是把 RNN 的隱層換成了存儲塊(memeory block),如下圖左所示, memory block 中用記憶單元 (memory cell)來保存信息(類似於 RNN 中的隱藏節點),,每個存儲塊包含一個或多個memory cell ,如下圖左中間的 “$\oslash$” 節點如下圖所示,藍色虛線為一條遞歸自連接的權值為 1 的邊,保證梯度沿時間傳播時不會損失,在時刻 $t$  的輸入如下圖的 $g^t$ 所示,除接受本時刻的輸入 $x^t$ 外,還接受上一時刻的輸出 $h^{t-1}$ ,並且經過非線性激活函數 $\sigma$ ,LSTM 並不是接納所有輸入 $g^t$ ,而是在網絡中加入兩個門,輸入門(input gate)、輸出門(output gate), 門的節點數目與 memory cell 一一對應, input gate 如下圖的 $i^t$ 所示,跟輸入層一樣,接受 $x^t$ 與 $h^{t-1}$ ,經過  $\sigma$ 后產生一個 0-1 向量(維度即為 memory cell 或者 input gate 的維度),0 代表關閉 、1 代表開啟,這樣來對輸入進行控制,下圖左中的 “$\prod$ ” 表示 input gate 的輸出  $i^t$ 與本時刻輸入 $g^t$ 的輸出逐元素相乘,即 input gate 會對輸入進行過濾 ,然后存放到 memory cell 里,現在memory cell 里既有上一時刻 $t-1$ 的狀態,又添加了時刻 $t$ 的狀態, 即

\[s^t = g^t \odot i^t + s^{t-1}\]

memory cell 有一個循環自連接的權值為 1 的邊,這樣 memory cell state 中梯度沿時間傳播時不會導致不會 vanishing 或者 exploding ,output gate 類似於 input gate 會產生一個 0-1 向量來控制 memory cell 到輸出層的輸出。即

\[ v^t = s^t \odot o^t  \]

后來為了增強 LSTM 的處理能力, Gers et al. [2000] 引入了 forget gate, LSTM 的網絡結構變成了如上圖右所示,也就是說 forget gate 取代了之前權值為 1 的邊,經過這樣的改進,memory cell 可以遺忘之前的內容,只需將 memory cell 中的內容與 forget gate 逐元素相乘即可, forget gate  與 input/output gate 一樣,接受  $x^t$ 與 $h^{t-1}$ 作為輸入,現在的 LSTM memory cell 的更新公式為

\[s^t = g^t \odot i^t + f^t \odot s^{t-1}\]

Gers & Schmidhuber [2000] 在以上結構的基礎上又提出了 peephole connections ,將 $t-1$ 時刻沒有經過 output gate 處理過的 memory cell 狀態送到時刻   $t$ 作為 input gate 和 output gate 的輸入,即三個門的輸入增加了了  $s^ {t-1}$ ,現在流行的網絡結構如下圖所示:

三個門協作使得  LSTM 存儲塊可以存取長期信息,比如說只要輸入門保持關閉,記憶單元的信息就不會被之后時刻的輸入所覆蓋。下圖形象的描述了這個過程,在 Hidden Layer 中每個節點都是一個 memeory block ,每個 memeory block 的包含三個門,左邊為 forget gate ,下邊尾 input gate ,上邊為 output gate ,門有打開關閉兩種狀態,分別由 "$\bigcirc $" 與 "$-$" 來表示。可見對於時刻 1 的輸入,只要之后時刻的 input gate 保持關閉,forget gate 保持打開,便可以在不影響 memory cell 的情況下隨時開啟 output gate 來獲得 memory cell 的內容。對於梯度反向傳播時,同樣可以通過這種方式來保持殘差不會過度衰減。

接下來本文所涉及的將是詳細 LSTM 的 BP 過程,網絡結構采用的是 Gers & Schmidhuber [2000]所提出的 LSTM 結構,值得注意的是,這里對 memory cell 的輸出增加了激活函數 $h$ , 之前的 $h$ 可以理解為線性的,這里先聲明一些符號表示: $w_{ij}$ 表示 單元 $i$ 到單元 $j$ 的權值,$a_j^t$ 表示時刻 $t$ 單元  $j$ 的輸入,$b_j^t = f(a_j^t)$ 表示對單元 $j$ 的輸入做非線性映射,$\iota$  、 $\phi$  、 $\omega$ 分別代表 input gate 、forget gate、 output gate ,$C$ 用來表示 memroy cell 的數量,  $s^t_c$ 表示 memeory cell $c$ 在時刻  $t$ 的狀態, $f$ 表示門的激活函數(通常為 $sigmod$ 函數), $g$ 與 $h$ 分別表示 memory cell 輸入與輸出的激活函數,用 $I$ 表示輸入層大小, $H$ 表示隱層 memory cell 的大小(其實 $H = C$,這里只是為了方便表示,因為 memory cell 的輸出   $b_h^t$ 會往下個時刻傳輸,其權值可表示為 $w_{h.}$ , memrory cell 本身的權值可用  $w_ {c.}$ 來表示) , $K$ 表示輸出層的大小。 待序列為 $t = 1...T$ ,時刻 $t$ 的輸入為 $x^t$ ,注意 $b^0 = 0$ , 殘差 $\delta ^{T+1} = 0$ 。

  • forget gate : 在 LSTM 的 memory block 中,只有上一時刻 memory cell 的輸出 $ b_h^t$ 會傳送到本單元 ,其他數據比如 memory cell state 或者 memory cell  input 等只在單元內部可見,forget gate 是用來控制上個時刻的 memory cell state 即 $s^{t-1}$ :

\[a^t_{\phi } = \sum_iw_{i \phi } x_i^t + \sum_hw_{h \phi}b_{h}^{t-1}+ \sum_cw_{c\phi}s_c^{t-1} \]

\[b_{\phi }^t = f(a_{\phi}^t)\]

  • input gate : 這個門控制當前時刻 memory cell state 的輸入:

\[a^t_{\iota } = \sum_iw_{i \iota } x_i^t + \sum_hw_{h \iota}b_{h}^{t-1}+ \sum_cw_{c\iota}s_c^{t-1} \]

\[b_{\iota }^t = f(a_{\iota}^t)\]

  • memory cell : 對於時刻 $t-1 \rightarrow  t$ , memroy cell 的信息是這樣變化的 ,首先對 $t-1$  時刻 memory cell 的狀態用 forget gate 進行過濾($b_{\phi}^t s_c^{t-1}$),看要遺忘或者保存哪些信息,然后獲取現在時刻 $t$ 的輸入信息($g(a_c^t)$),用 input gate 進行過濾 ($b_{\iota }^tg(a_c^t)$),過濾完后相加就完成了$t-1 \rightarrow  t$ 時刻的 memory cell 狀態的轉變 :

\[a^t_c = \sum_i w_{ic} x_i^t + \sum_h w_{hc}b_{h}^{t-1} \]
\[s_c^t = b_{\phi}^t s_c^{t-1} + b_{\iota }^tg(a_c^t)\]

  • output gate : 這個門會控制 cell state 的輸出:

\[a^t_{\omega } = \sum_iw_{i \omega } x_i^t + \sum_hw_{h \omega }b_{h}^{t-1}+ \sum_cw_{c\omega }s_c^{t} \]

\[b_{\omega }^t = f(a_{\omega }^t)\]

  • memory cell output : 計算 memory cell 的輸出 ,由 output gate 控制,這個輸出也會作為下一時刻整個 memory block 的輸入(類似於 RNN 的隱層傳遞)

\[b_c^t = b_{\omega}^t h(s_c^t)\]

接下來便是殘差的反向傳導,對於輸出層,同 RNN 一般是 $softmax$ 或者 $logistic$ ,這里首先定義:

\[\epsilon_c^t=\frac{\partial O}{\partial b_c^t}=\sum_k\frac{\partial O}{\partial a_k^t} \frac{\partial a_k^t}{\partial b_c^t}+\sum_{h}\frac{\partial O}{\partial a_h^t} \frac{\partial a_h^t}{\partial b_c^t}=\sum_{k} w_{ck}\delta_k^t+\sum_hw_{ch}\delta_h^{t+1} \] 

接下來,殘差傳導至 output gate :

\[\delta_\omega^t=\frac{\partial O}{\partial a_\omega^t}=\sum_c \frac{\partial O}{\partial b_c^t}\frac{\partial b_c^t}{\partial b_\omega^t}\frac{\partial b_\omega^t}{\partial a_\omega^t} =f'(a_\omega^t)\sum_c \epsilon_c^t h(s_c^t) \]

現在再定義一個輔助變量:

\[\epsilon_s^t=\frac{\partial \mathcal{L}}{\partial s_c^t}
=\frac{\partial O}{\partial b_c^t} \frac{\partial b_c^t}{\partial h(s_c^t)} \frac{\partial h(s_c^t)}{\partial s_c^t}
+\frac{\partial O}{\partial s_c^{t+1}} \frac{\partial s_c^{t+1}}{\partial s_c^t}
+\frac{\partial O}{\partial a_\omega^t} \frac{\partial a_\omega^t}{\partial s_c^t}
+\frac{\partial O}{\partial a_\iota^t} \frac{\partial a_\iota^t}{\partial s_c^t}
+\frac{\partial O}{\partial a_\phi^t} \frac{\partial a_\phi^t}{\partial s_c^t} \Rightarrow\]

\[\epsilon_s^t=b_w^th'(s_c^t)\epsilon_c^t+b_\phi^{t+1}\epsilon_s^{t+1}+w_{c\omega}\delta_\omega^t+w_{c\iota}\delta_\iota^{t+1} +w_{c\phi}\delta_\phi^{t+1}\]

這就是 bp 中最復雜的公式了,依次解釋下各項。首先,看memory block的圖,查看該單元指向輸出單元的所有路徑,沒有一條不同的路徑就代表一項;然后運用鏈式法則展開每個路徑;就得到后向傳播中該單元的梯度$\delta$。這個輔助變量中可以看到后三項來自於cell state 對三個 gate 的監督,即 peephole ,所以若不采用 peephole 的方式就可以省略。第二項來自於下一時刻的狀態誤差,其實是 forget gate 對當前狀態的調節作用。

接下來誤差傳播到 memory cell :

\[\delta_c^t =\frac{\partial O}{\partial a_c^t}=\frac{\partial O}{\partial s_c^t}\frac{\partial s_c^t}{\partial g(a_c^t)}\frac{\partial g(a_c^t)}{\partial a_c^t}=\epsilon_c^t b_\iota^t g'(a_c^t)\]

最后分別傳導至 forget gate $\phi$ 與 輸入門 $\iota$:

\[\delta_\phi^t =\frac{\partial O}{\partial a_\phi^t}=\sum_c\frac{\partial O}{\partial s_c^t}\frac{\partial s_c^t}{\partial b_\phi^t}\frac{\partial b_\phi^t}{\partial a_\phi^t}=f'(a_\phi^t)\sum_c s_c^{t-1}\epsilon_s^t \]

\[\delta_\iota^t =\frac{\partial O}{\partial a_\iota^t}=\sum_c\frac{\partial O}{\partial s_c^t}\frac{\partial s_c^t}{\partial b_\iota^t}\frac{\partial b_\iota^t}{\partial a_\iota^t}=f'(a_\iota^t)\sum_c g(a_c^{t-1})\epsilon_s^t\]

 殘差傳導完成后,直接用殘差對權重 $w_{ij}$ 進行求導即可 (這里 $b_i^t$ 可代表輸入 $x_i^t$、$b_h^{t-1}$、$s_c^{t-1}$):

\[\frac{\partial O}{\partial w_{ij}} = \sum_t \frac{\partial O}{\partial a_j^t}\frac{\partial a_j^t}{\partial w_{ij}} = \sum_t \delta_j^tb_i^t\]

參考:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

     Supervised Sequence Labelling with Recurrent Neural Networks

     http://ethancao.cn/2015/12/07/learning-LSTM.html 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM