為什么LSTM可以防止梯度消失?從反向傳播的角度分析


為什么LSTM可以防止梯度消失?從反向傳播的角度分析
 
 
 
 
LSTM:溫和的巨人
 
相比於RNN,雖然LSTM(或者GRU)看上去復雜而臃腫,但是LSTM(或者GRU)在實際中的效果是非常好的,它可以解決RNN中出現的梯度消失的問題。
 
梯度消失是指,在反向傳播時,梯度值隨着反向傳播呈指數下降,最終造成的影響是越靠近輸入的層梯度值越接近0,這些層因此無法得到有效的訓練。對於RNN,這意味着無法跟蹤任何長期依賴關系。 這是一種麻煩,因為RNN的全部意義在於跟蹤長期依賴關系。
 
接下來介紹為什么LSTM(及其相關的模型)可以解決梯度消失的問題。
 
下面先介紹LSTM相關的符號。

 

 LSTM的公式如下(省略bias)

 

 

梯度消失的情況
 
為了理解LSTM為什么有幫助,我們需要先理解普通RNN( vanilla RNNs)中出現的問題。在普通的RNN中,隱含層向量和輸出的計算方式如下:

 

 要通過時間進行反向傳播 backpropagation through time)來訓練RNN,我們需要計算E關於的梯度。 總誤差梯度等於每一時間步的誤差梯度之和。對於時間t,我們可以使用鏈式法則來推導出誤差梯度,如下:

 

 上面的式子中,的具體形式如下:

 

 

關於 的偏導數如下:

 

 

其中,diag函數將一個向量轉換為對角矩陣。
 
因此,如果我們通過時間步t來進行反向傳播,梯度表示如下:

 

 

參考這篇文章( On the difficulty of training Recurrent Neural Networks ),如果矩陣 的主特征值( dominant eigenvalue)大於1,那么就會產生梯度爆炸(gradient explodes);如果小於1,那么就會產生梯度消失(gradient vanishes)。注意到 的值總是小於1,因此如果 的值太小,將不可避免的會造成梯度值變成0;如果 的值很大,那么導數/梯度就會變得很大。在實際中,梯度消失更加常見,因此我們更關注於梯度消失問題。
 
導數 可以告訴我們當我們改變時刻l(小寫的L)的隱層狀態時,時刻k的隱層狀態將會改變多少。根據上面的數學公式,梯度消失的意思是前面隱藏層( earlier hidden states)將對后面的隱藏層(later hidden states)不產生影響,這意味着沒有學到長期依賴關系(no long term dependencies are learned)。具體的證明可以參考 原始的LSTM文章上面提到的那篇文章

 

使用LSTM來防止梯度消失
 
正如上面提到,造成梯度消失的最大原因就是我們需要計算遞歸導數 ,我們如果可以解決這個問題,那么我們就可以學到長期依賴關系(long term dependencies)
 
針對這個問題,最原始的LSTM是這樣解決的:使得遞歸導數(recursive derivative)的值為常量。在這種情況下,梯度就不會消失或者爆炸。該如何實現這一點呢?LSTM引進了一個單獨的cell state 。在最原始的1997年版本的LSTM, 的值取決於前一個cell state的值和按input gate加權的更新項(使用input gate的motivation可以參考 這篇文章),具體公式如下:

 

 上面的公式效果並不好,原因是cell state可能會增長得無法控制。為了防止這個無限增長,引入了forget gate,公式如下:

 

 

 
一個常見的誤解。LSTM為什么可以解決梯度消失的問題,大多數解釋是在上述的更新公式下,遞歸導數(recursive derivative)的值等於1(原始的LSTM)或者值等於f(改進后的LSTM)。其中一個容易忘記的是,f、i和 都是關於 的函數,因此我們在計算梯度時必須將它們考慮在內。

接下來看看完整的LSTM的梯度。上面我們提到遞歸導數是造成梯度消失的主要原因,因此我們來解釋一下完整的導數。通過鏈式求導法則,我們可以得到

 

 上述求導具體可以寫為:

 

 

現在,如果我們要反向傳播k個時間步,我們只要簡單的將上述公式連乘k次就行。這與普通的RNN有很大的區別。對於普通的RNN,的最終要么總是大於1,要么總是在[0, 1]范圍內,這將導致梯度消失或者梯度爆炸。而對於LSTM,在任何時間步,該值可以大於1,或者在[0, 1]范圍內。因此,如果我們延伸到無窮的時間步,最終並不會收斂到0或者無窮。如果開始收斂到0,那么可以總是設置的值(或者其他gate的值)更高一些,使得的值接近1,從而防止了梯度消失(或者至少是,防止梯度不會那么快消失)。另外一個很重要的事情是,的值是網絡學習到的(根據當前的輸入和隱藏層)。因此,在這種情況下,網絡會學會決定什么時候讓梯度消失,什么時候保持梯度,都可以通過設置gate的值來決定。

這看起來很神奇,但實際上如下兩個原因:
  • 為cell state的更新函數給出了一個更加“表現良好”的導數( The additive update function for the cell state gives a derivative thats much more ‘well behaved’
  • 門控函數(gating function)允許網絡決定梯度消失多少,並且可以在每個時間步長取不同的值。它們所取的值是從當前輸入和隱藏狀態學習到的。( The gating functions allow the network to decide how much the gradient vanishes, and can take on different values at each time step. The values that they take on are learned functions of the current input and hidden state.
 
以上就是LSTM解決梯度消失的本質。

 

附:
  1. recursive partial derivative 是一個雅可比矩陣( Jacobian  matrix)。
  2. 為了直觀地理解遞歸權值矩陣的特征值的重要性,可以參考這篇文章
  3. 對於LSTM的遺忘門(forget gate),遞歸導數仍然是許多0和1之間的數的和,然而在實踐中,與RNN相比,這不是一個很大的問題。其中一個原因是我們的網絡可以直接控制f的值。如果需要記住一些內容,網絡可以很容易得將f取值高一點(如0.95左右)。因此,與tanh的導數值相比,這些值得收縮速度要慢的多。
  4. 為了完成完整的LSTM的推導,其實還有很多細節需要完成。本文不再贅述,感興趣的可以參考這篇文章 PhD thesis of Alex Graves
 
 
其他相關的文章鏈接:
 
RNN梯度消失和爆炸的原因  https://zhuanlan.zhihu.com/p/28687529
為什么相比於RNN,LSTM在梯度消失上表現更好?  https://www.zhihu.com/question/44895610
Why LSTMs Stop Your Gradients From Vanishing:A View from the Backwards Pass  https://weberna.github.io/blog/2017/11/15/LSTM-Vanishing-Gradients.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM