RNN梯度消失和爆炸的原因 以及 LSTM如何解決梯度消失問題


RNN梯度消失和爆炸的原因

經典的RNN結構如下圖所示:



假設我們的時間序列只有三段, S_{0} 為給定值,神經元沒有激活函數,則RNN最簡單的前向傳播過程如下:

S_{1}=W_{x}X_{1}+W_{s}S_{0}+b_{1}O_{1}=W_{o}S_{1}+b_{2}

S_{2}=W_{x}X_{2}+W_{s}S_{1}+b_{1}O_{2}=W_{o}S_{2}+b_{2}

S_{3}=W_{x}X_{3}+W_{s}S_{2}+b_{1}O_{3}=W_{o}S_{3}+b_{2}

假設在t=3時刻,損失函數為 L_{3}=\frac{1}{2}(Y_{3}-O_{3})^{2} 

則對於一次訓練任務的損失函數為 L=\sum_{t=0}^{T}{L_{t}} ,即每一時刻損失值的累加。

使用隨機梯度下降法訓練RNN其實就是對 W_{x}  W_{s}  W_{o} 以及 b_{1}b_{2} 求偏導,並不斷調整它們以使L盡可能達到最小的過程。

現在假設我們我們的時間序列只有三段,t1,t2,t3。

我們只對t3時刻的 W_{x}、W_{s}、W_{0} 求偏導(其他時刻類似):

\frac{\partial{L_{3}}}{\partial{W_{0}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{W_{o}}}

\frac{\partial{L_{3}}}{\partial{W_{x}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{W_{x}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{W_{x}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{S_{1}}}\frac{\partial{S_{1}}}{\partial{W_{x}}}

\frac{\partial{L_{3}}}{\partial{W_{s}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{S_{1}}}\frac{\partial{S_{1}}}{\partial{W_{s}}}

可以看出對於 W_{0} 求偏導並沒有長期依賴,但是對於 W_{x}、W_{s} 求偏導,會隨着時間序列產生長期依賴。因為 S_{t} 隨着時間序列向前傳播,而 S_{t} 又是 W_{x}、W_{s}的函數。

根據上述求偏導的過程,我們可以得出任意時刻對 W_{x}、W_{s} 求偏導的公式:

\frac{\partial{L_{t}}}{\partial{W_{x}}}=\sum_{k=0}^{t}{\frac{\partial{L_{t}}}{\partial{O_{t}}}\frac{\partial{O_{t}}}{\partial{S_{t}}}}(\prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}})\frac{\partial{S_{k}}}{\partial{W_{x}}}

任意時刻對W_{s} 求偏導的公式同上。

如果加上激活函數, S_{j}=tanh(W_{x}X_{j}+W_{s}S_{j-1}+b_{1}) 

 \prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}W_{s}

激活函數tanh和它的導數圖像如下。


由上圖可以看出 tanh^{'}\leq1 ,對於訓練過程大部分情況下tanh的導數是小於1的,因為很少情況下會出現W_{x}X_{j}+W_{s}S_{j-1}+b_{1}=0 ,如果 W_{s} 也是一個大於0小於1的值,則當t很大時 \prod_{j=k+1}^{t}{tanh^{'}}W_{s} ,就會趨近於0,和 0.01^{50} 趨近與0是一個道理。同理當 W_{s} 很大時 \prod_{j=k+1}^{t}{tanh^{'}}W_{s} 就會趨近於無窮,這就是RNN中梯度消失和爆炸的原因。

至於怎么避免這種現象,讓我在看看 \frac{\partial{L_{t}}}{\partial{W_{x}}}=\sum_{k=0}^{t}{\frac{\partial{L_{t}}}{\partial{O_{t}}}\frac{\partial{O_{t}}}{\partial{S_{t}}}}(\prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}})\frac{\partial{S_{k}}}{\partial{W_{x}}} 梯度消失和爆炸的根本原因就是 \prod_{j=k+1}^{t}{\frac{\partial{S_{j}}}{\partial{S_{j-1}}}} 這一坨,要消除這種情況就需要把這一坨在求偏導的過程中去掉,至於怎么去掉,一種辦法就是使 {\frac{\partial{S_{j}}}{\partial{S_{j-1}}}}\approx1 另一種辦法就是使 {\frac{\partial{S_{j}}}{\partial{S_{j-1}}}}\approx0 。其實這就是LSTM做的事情。

LSTM如何解決梯度消失問題

先上一張LSTM的經典圖:


至於這張圖的詳細介紹請參考:Understanding LSTM Networks

下面假設你已經閱讀過Understanding LSTM Networks這篇文章了,並且了解了LSTM的組成結構。

RNN梯度消失和爆炸的原因這篇文章中提到的RNN結構可以抽象成下面這幅圖:


而LSTM可以抽象成這樣:


三個×分別代表的就是forget gate,input gate,output gate,而我認為LSTM最關鍵的就是forget gate這個部件。這三個gate是如何控制流入流出的呢,其實就是通過下面 f_{t},i_{t},o_{t} 三個函數來控制,因為 \sigma(x)(代表sigmoid函數) 的值是介於0到1之間的,剛好用趨近於0時表示流入不能通過gate,趨近於1時表示流入可以通過gate。

f_{t}=\sigma({W_{f}X_{t}}+b_{f})

i_{t}=\sigma({W_{i}X_{t}}+b_{i})

o_{i}=\sigma({W_{o}X_{t}}+b_{o})

當前的狀態 S_{t}=f_{t}S_{t-1}+i_{t}X_{t}類似與傳統RNN S_{t}=W_{s}S_{t-1}+W_{x}X_{t}+b_{1}。將LSTM的狀態表達式展開后得:

S_{t}=\sigma(W_{f}X_{t}+b_{f})S_{t-1}+\sigma(W_{i}X_{t}+b_{i})X_{t}

如果加上激活函數, S_{t}=tanh\left[\sigma(W_{f}X_{t}+b_{f})S_{t-1}+\sigma(W_{i}X_{t}+b_{i})X_{t}\right]

RNN梯度消失和爆炸的原因這篇文章中傳統RNN求偏導的過程包含 \prod_{j=k+1}^{t}\frac{\partial{S_{j}}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}{tanh{'}W_{s}}

對於LSTM同樣也包含這樣的一項,但是在LSTM中 \prod_{j=k+1}^{t}\frac{\partial{S_{j}}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}{tanh{’}\sigma({W_{f}X_{t}+b_{f}})}

假設 Z=tanh{'}(x)\sigma({y}) ,則 Z 的函數圖像如下圖所示:


可以看到該函數值基本上不是0就是1。

傳統RNN的求偏導過程:

\frac{\partial{L_{3}}}{\partial{W_{s}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{S_{1}}}\frac{\partial{S_{1}}}{\partial{W_{s}}}

如果在LSTM中上式可能就會變成:

\frac{\partial{L_{3}}}{\partial{W_{s}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{2}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{1}}}{\partial{W_{s}}}

因為 \prod_{j=k+1}^{t}\frac{\partial{S_{j}}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}{tanh{’}\sigma({W_{f}X_{t}+b_{f}})}\approx0|1 ,這樣就解決了傳統RNN中梯度消失的問題。



來源:

 https://zhuanlan.zhihu.com/p/28687529







免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM