RNN梯度消失和爆炸的原因
經典的RNN結構如下圖所示:

假設我們的時間序列只有三段, 為給定值,神經元沒有激活函數,則RNN最簡單的前向傳播過程如下:
假設在t=3時刻,損失函數為 。
則對於一次訓練任務的損失函數為 ,即每一時刻損失值的累加。
使用隨機梯度下降法訓練RNN其實就是對 、
、
以及
求偏導,並不斷調整它們以使L盡可能達到最小的過程。
現在假設我們我們的時間序列只有三段,t1,t2,t3。
我們只對t3時刻的 求偏導(其他時刻類似):
可以看出對於 求偏導並沒有長期依賴,但是對於
求偏導,會隨着時間序列產生長期依賴。因為
隨着時間序列向前傳播,而
又是
的函數。
根據上述求偏導的過程,我們可以得出任意時刻對 求偏導的公式:
任意時刻對 求偏導的公式同上。
如果加上激活函數, ,
則 =
激活函數tanh和它的導數圖像如下。

由上圖可以看出 ,對於訓練過程大部分情況下tanh的導數是小於1的,因為很少情況下會出現
,如果
也是一個大於0小於1的值,則當t很大時
,就會趨近於0,和
趨近與0是一個道理。同理當
很大時
就會趨近於無窮,這就是RNN中梯度消失和爆炸的原因。
至於怎么避免這種現象,讓我在看看 梯度消失和爆炸的根本原因就是
這一坨,要消除這種情況就需要把這一坨在求偏導的過程中去掉,至於怎么去掉,一種辦法就是使
另一種辦法就是使
。其實這就是LSTM做的事情。
LSTM如何解決梯度消失問題
先上一張LSTM的經典圖:

至於這張圖的詳細介紹請參考:Understanding LSTM Networks
下面假設你已經閱讀過Understanding LSTM Networks這篇文章了,並且了解了LSTM的組成結構。
RNN梯度消失和爆炸的原因這篇文章中提到的RNN結構可以抽象成下面這幅圖:

而LSTM可以抽象成這樣:

三個×分別代表的就是forget gate,input gate,output gate,而我認為LSTM最關鍵的就是forget gate這個部件。這三個gate是如何控制流入流出的呢,其實就是通過下面 三個函數來控制,因為
(代表sigmoid函數) 的值是介於0到1之間的,剛好用趨近於0時表示流入不能通過gate,趨近於1時表示流入可以通過gate。
當前的狀態 類似與傳統RNN
。將LSTM的狀態表達式展開后得:
如果加上激活函數,
RNN梯度消失和爆炸的原因這篇文章中傳統RNN求偏導的過程包含
對於LSTM同樣也包含這樣的一項,但是在LSTM中
假設 ,則
的函數圖像如下圖所示:

可以看到該函數值基本上不是0就是1。
傳統RNN的求偏導過程:
如果在LSTM中上式可能就會變成:
因為 ,這樣就解決了傳統RNN中梯度消失的問題。
https://zhuanlan.zhihu.com/p/28687529