RNN梯度消失&爆炸原因解析與LSTM&GRU的對其改善


一、關於RNN的梯度消失&爆炸問題

1. 關於RNN結構

循環神經網絡RNN(Recurrent Neural Network)是用於處理序列數據的一種神經網絡,已經在自然語言處理中被廣泛應用。下圖為經典RNN結構:

2. 關於RNN前向傳播

RNN前向傳導公式:

其中:  St :  t 時刻的隱含層狀態值

Ot :  t 時刻的輸出值

① 是隱含層計算公式,U是輸入x的權重矩陣,W是時刻t-1的狀態值

St-1作為輸入的權重矩陣,Φ是激活函數。

② 是輸出層計算公式,V是輸出層的權重矩陣,f是激活函數。

損失函數(loss function)采用交叉熵( Ot 是t時刻預測輸出, 是 t 時刻正確的輸出) 

那么對於一次訓練任務中,損失函數:, T 是序列總長度。

假設初始狀態St為0,t=3 有三段時間序列時,由 ① 帶入②可得到 

t1、t2、t3 各個狀態和輸出

3. 關於RNN反向傳播

BPTT(back-propagation through time)算法是針對循層的訓練算法,它的基本原理和BP算法一樣。其算法本質還是梯度下降法,那么該算法的關鍵就是計算各個參數的梯度,對於RNN來說參數有 U、W、V。

反向傳播

可以簡寫成:

觀察③④⑤式,可知,對於 V 求偏導不存在依賴問題;但是對於 W、U 求偏導的時候,由於時間序列長度,存在長期依賴的情況。主要原因可由 t=1、2、3 的情況觀察得 , St會隨着時間序列向前傳播,同時St是 U、W 的函數。

前面得出的求偏導公式⑥,取其中累乘的部分出來,其中激活函數 Φ 通常是:tanh 則

由上圖可知當激活函數是tanh函數時,tanh函數的導數最大值為1,又不可能一直都取1這種情況,而且這種情況很少出現,那么也就是說,大部分都是小於1的數在做累乘,若當t很大的時候,趨向0,舉個例子:0.850=0.00001427247也已經接近0了,這是RNN中梯度消失的原因。

但要注意:RNN 中總的梯度是不會消失的。即便梯度越傳越弱,那也只是遠距離的梯度消失,由於近距離的梯度不會消失,所有梯度之和便不會消失。RNN 所謂梯度消失的真正含義是,梯度被近距離梯度主導,導致模型難以學到遠距離的依賴關系。

再看⑦部分:

tanh’,還需要網絡參數 W ,如果參數 W 中的值太大,隨着序列長度同樣存在長期依賴的情況,那么產生問題就是梯度爆炸,而不是梯度消失了,在平時運用中,RNN比較深,使得梯度爆炸或者梯度消失問題會比較明顯。

二、LSTM緩解梯度消失

至於怎么避免這種現象,讓我在看看 \frac{\partial{L_{t}}}{\partial{W_{x}}}=\sum_{k=0}^{t}{\frac{\partial{L_{t}}}{\partial{O_{t}}}\frac{\partial{O_{t}}}{\partial{S_{t}}}}(\prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}})\frac{\partial{S_{k}}}{\partial{W_{x}}} 梯度消失和爆炸的根本原因就是 \prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}} 這一坨,要消除這種情況就需要把這一坨在求偏導的過程中去掉,至於怎么去掉,一種辦法就是使 {\frac{\partial{S_{j}}}{\partial{S_{j-1}}}}\approx1 另一種辦法就是使 {\frac{\partial{S_{j}}}{\partial{S_{j-1}}}}\approx0 。其實這就是LSTM做的事情。

我們來看看LSTM的內部結構,包含了四個門層結構:

引用自 Stanford CS231n slides

LSTM相信很多人看過這個:[譯] 理解 LSTM 網絡,但是我發現cs231n的公式更加簡潔,把四個門層結構的權重參數合成一個W。

求導過程比較復雜,我們先看一下[公式]這一項:

[公式]

和前面一樣,我們來求一下[公式] ,這里注意[公式] ,[公式]和 [公式] 都是 [公式]的復合函數:

[公式]

后面的我們就不管了,展開求導太麻煩了,第一項[公式]是什么!大聲告訴我! [公式]是forget gate的輸出值,1表示完全保留舊狀態,0表示完全舍棄舊狀態,那如果我們把 [公式]設置成1或者是接近於1,那 [公式] 這一項就有妥妥的梯度了。

因此LSTM是靠着cell結構來保留梯度,forget gate控制了對過去信息的保留程度,如果gate選擇保留舊狀態,那么梯度就會接近於1,可以緩解梯度消失問題。這里說緩解,是因為LSTM只是在 [公式]到 [公式]這條路上解決梯度消失問題,而其他路依然存在梯度消失問題。

而且forget gate解決了RNN中的長期依賴問題,不管網絡多深,也可以記住之前的信息。

另外,LSTM可以緩解梯度消失,但是梯度爆炸並不能解決,但實際上前面也講過,梯度爆炸不是什么大問題(閾值裁剪)。

 三、GRU緩解梯度消失

LSTM內部結構比較復雜,因此衍生了簡化版GRU,把LSTM的input gate和forget gate整合成一個update gate,也是通過gate機制來控制梯度:

我們還是來求一下 [公式] ,我們可以得到:[公式] ,那一串省略號我們還是不管,我們依然可以通過控制 [公式] 來控制梯度。

所以,我們現在可以看到,LSTM系列都是通過gate機制來緩解梯度消失問題的。

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM