一、關於RNN的梯度消失&爆炸問題
1. 關於RNN結構
循環神經網絡RNN(Recurrent Neural Network)是用於處理序列數據的一種神經網絡,已經在自然語言處理中被廣泛應用。下圖為經典RNN結構:
2. 關於RNN前向傳播
RNN前向傳導公式:
其中: St : t 時刻的隱含層狀態值
Ot : t 時刻的輸出值
① 是隱含層計算公式,U是輸入x的權重矩陣,W是時刻t-1的狀態值
St-1作為輸入的權重矩陣,Φ是激活函數。
② 是輸出層計算公式,V是輸出層的權重矩陣,f是激活函數。
損失函數(loss function)采用交叉熵( Ot 是t時刻預測輸出,
是 t 時刻正確的輸出)
那么對於一次訓練任務中,損失函數:, T 是序列總長度。
假設初始狀態St為0,t=3 有三段時間序列時,由 ① 帶入②可得到
t1、t2、t3 各個狀態和輸出
3. 關於RNN反向傳播
BPTT(back-propagation through time)算法是針對循層的訓練算法,它的基本原理和BP算法一樣。其算法本質還是梯度下降法,那么該算法的關鍵就是計算各個參數的梯度,對於RNN來說參數有 U、W、V。
反向傳播
可以簡寫成:
觀察③④⑤式,可知,對於 V 求偏導不存在依賴問題;但是對於 W、U 求偏導的時候,由於時間序列長度,存在長期依賴的情況。主要原因可由 t=1、2、3 的情況觀察得 , St會隨着時間序列向前傳播,同時St是 U、W 的函數。
前面得出的求偏導公式⑥,取其中累乘的部分出來,其中激活函數 Φ 通常是:tanh 則
由上圖可知當激活函數是tanh函數時,tanh函數的導數最大值為1,又不可能一直都取1這種情況,而且這種情況很少出現,那么也就是說,大部分都是小於1的數在做累乘,若當t很大的時候,趨向0,舉個例子:0.850=0.00001427247也已經接近0了,這是RNN中梯度消失的原因。
但要注意:RNN 中總的梯度是不會消失的。即便梯度越傳越弱,那也只是遠距離的梯度消失,由於近距離的梯度不會消失,所有梯度之和便不會消失。RNN 所謂梯度消失的真正含義是,梯度被近距離梯度主導,導致模型難以學到遠距離的依賴關系。
再看⑦部分:
tanh’,還需要網絡參數 W ,如果參數 W 中的值太大,隨着序列長度同樣存在長期依賴的情況,那么產生問題就是梯度爆炸,而不是梯度消失了,在平時運用中,RNN比較深,使得梯度爆炸或者梯度消失問題會比較明顯。
二、LSTM緩解梯度消失
至於怎么避免這種現象,讓我在看看 梯度消失和爆炸的根本原因就是
這一坨,要消除這種情況就需要把這一坨在求偏導的過程中去掉,至於怎么去掉,一種辦法就是使
另一種辦法就是使
。其實這就是LSTM做的事情。
我們來看看LSTM的內部結構,包含了四個門層結構:
引用自 Stanford CS231n slides
LSTM相信很多人看過這個:[譯] 理解 LSTM 網絡,但是我發現cs231n的公式更加簡潔,把四個門層結構的權重參數合成一個W。
求導過程比較復雜,我們先看一下這一項:
和前面一樣,我們來求一下 ,這里注意
,
和
都是
的復合函數:
后面的我們就不管了,展開求導太麻煩了,第一項是什么!大聲告訴我!
是forget gate的輸出值,1表示完全保留舊狀態,0表示完全舍棄舊狀態,那如果我們把
設置成1或者是接近於1,那
這一項就有妥妥的梯度了。
因此LSTM是靠着cell結構來保留梯度,forget gate控制了對過去信息的保留程度,如果gate選擇保留舊狀態,那么梯度就會接近於1,可以緩解梯度消失問題。這里說緩解,是因為LSTM只是在 到
這條路上解決梯度消失問題,而其他路依然存在梯度消失問題。
而且forget gate解決了RNN中的長期依賴問題,不管網絡多深,也可以記住之前的信息。
另外,LSTM可以緩解梯度消失,但是梯度爆炸並不能解決,但實際上前面也講過,梯度爆炸不是什么大問題(閾值裁剪)。
三、GRU緩解梯度消失
LSTM內部結構比較復雜,因此衍生了簡化版GRU,把LSTM的input gate和forget gate整合成一個update gate,也是通過gate機制來控制梯度:
我們還是來求一下 ,我們可以得到:
,那一串省略號我們還是不管,我們依然可以通過控制
來控制梯度。
所以,我們現在可以看到,LSTM系列都是通過gate機制來緩解梯度消失問題的。