LSTM如何解決梯度消失或爆炸的?


from:https://zhuanlan.zhihu.com/p/44163528

哪些問題?

  • 梯度消失會導致我們的神經網絡中前面層的網絡權重無法得到更新,也就停止了學習。
  • 梯度爆炸會使得學習不穩定, 參數變化太大導致無法獲取最優參數。
  • 在深度多層感知機網絡中,梯度爆炸會導致網絡不穩定,最好的結果是無法從訓練數據中學習,最壞的結果是由於權重值為NaN而無法更新權重。
  • 在循環神經網絡(RNN)中,梯度爆炸會導致網絡不穩定,使得網絡無法從訓練數據中得到很好的學習,最好的結果是網絡不能在長輸入數據序列上學習。

3. 原因何在?

讓我們以一個很簡單的例子分析一下,這樣便於理解。

 

 

如上圖,是一個每層只有一個神經元的神經網絡,且每一層的激活函數為sigmoid,則有:

y_i = \sigma(z_i) = \sigma(w_ix_i + b_i)  ( \sigma 是sigmoid函數)。

我們根據反向傳播算法有:

 \frac{\delta C}{\delta b_1} = \frac{\delta C}{ \delta y_4} \frac{\delta y_4}{\delta z_4} \frac{\delta z_4}{\delta x_4} \frac{ \delta x_4}{\delta z_3} \frac{\delta z_3}{ \delta x_3} \frac{ \delta x_3}{\delta z_2} \frac{\delta z_2}{ \delta x_2} \frac{ \delta x_2}{\delta z_1} \frac{\delta z_1}{\delta b_1} \\ = \frac{ \delta C}{\delta y_4} (\sigma '(z_4) w_4)( \sigma'(z_3) w_3)( \sigma ' (z_2) w_2)( \sigma ' (z_1))

而sigmoid函數的導數公式為:  S'(x ) = \frac{e^{-x}}{(1+ e^{-x})^2} = S(x)(1- S(x))  它的圖形曲線為:

 

 

由上可見,sigmoid函數的導數 \sigma'(x) 的最大值為 \frac{1}{4} ,通常我們會將權重初始值 |w| 初始化為為小於1的隨機值,因此我們可以得到 |\sigma '(z_4) w_4| < \frac{1}{4} ,隨着層數的增多,那么求導結果\frac{\delta C}{\delta b_1} 越小,這也就導致了梯度消失問題。

那么如果我們設置初始權重 |w| 較大,那么會有 |\sigma '(z_4) w_4| > 1  ,造成梯度太大(也就是下降的步伐太大),這也是造成梯度爆炸的原因。

總之,無論是梯度消失還是梯度爆炸,都是源於網絡結構太深,造成網絡權重不穩定,從本質上來講是因為梯度反向傳播中的連乘效應。

4. RNN中的梯度消失,爆炸問題

參考:RNN梯度消失和爆炸的原因, 這篇文章是我看到講的最清楚的了,在這里添加一些我的思考, 若侵立刪。

我們給定一個三個時間的RNN單元,如下:

我們假設最左端的輸入 S_0 為給定值, 且神經元中沒有激活函數(便於分析), 則前向過程如下:

S_1 = W_xX_1 + W_sS_0 + b_1 \qquad \qquad \qquad O_1 = W_oS_1 + b_2 \\ S_2 = W_xX_2 + W_sS_1 + b_1 \qquad \qquad \qquad O_2 = W_oS_2 + b_2 \\ S_3 = W_xX_3 + W_sS_2 + b_1 \qquad \qquad \qquad O_3 = W_oS_3 + b_2 \\

在 t=3 時刻, 損失函數為 L_3 = \frac{1}{2}(Y_3 - O_3)^2 ,那么如果我們要訓練RNN時, 實際上就是是對 W_x, W_s, W_o,b_1,b_2 求偏導, 並不斷調整它們以使得 L_3 盡可能達到最小(參見反向傳播算法與梯度下降算法)。

那么我們得到以下公式:

\frac{\delta L_3}{\delta W_0} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta W_0} \\ \frac{\delta L_3}{\delta W_x} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta W_x} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta W_x} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_x} \\ \frac{\delta L_3}{\delta W_s} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta W_s} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta W_s} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_s} \\

將上述偏導公式與第三節中的公式比較,我們發現, 隨着神經網絡層數的加深對 W_0 而言並沒有什么影響, 而對 W_x, W_s 會隨着時間序列的拉長而產生梯度消失和梯度爆炸問題。

根據上述分析整理一下公式可得, 對於任意時刻t對 W_x, W_s 求偏導的公式為:

\frac{\delta L_t}{\delta W_x } = \sum_{k=0}^t \frac{\delta L_t}{\delta O_t} \frac{\delta O_t}{\delta S_t}( \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} ) \frac{ \delta S_k }{\delta W_x} \\ \frac{\delta L_t}{\delta W_s } = \sum_{k=0}^t \frac{\delta L_t}{\delta O_t} \frac{\delta O_t}{\delta S_t}( \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} ) \frac{ \delta S_k }{\delta W_s}

我們發現, 導致梯度消失和爆炸的就在於 \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} , 而加上激活函數后的S的表達式為:

S_j = tanh(W_xX_j + W_sS_{j-1} + b_1)

那么則有:

\prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} = \prod_{j=k+1}^t tanh' W_s

而在這個公式中, tanh的導數總是小於1 的, 如果 W_s 也是一個大於0小於1的值, 那么隨着t的增大, 上述公式的值越來越趨近於0, 這就導致了梯度消失問題。 那么如果 W_s 很大, 上述公式會越來越趨向於無窮, 這就產生了梯度爆炸。

5. 為什么LSTM能解決梯度問題?

在閱讀此篇文章之前,確保自己對LSTM的三門機制有一定了解, 參見:LSTM:RNN最常用的變體

從上述中我們知道, RNN產生梯度消失與梯度爆炸的原因就在於 \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}}  , 如果我們能夠將這一坨東西去掉, 我們的不就解決掉梯度問題了嗎。 LSTM通過門機制來解決了這個問題。

我們先從LSTM的三個門公式出發:

  • 遺忘門: f_t = \sigma( W_f \cdot [h_{t-1}, x_t] + b_f)
  • 輸入門: i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)
  • 輸出門: o_t = \sigma(W_o \cdot [h_{t-1}, x_t ] + b_0 )
  • 當前單元狀態 c_t : c_t = f_t \circ c_{t-1} + i_t \circ tanh(W_c \cdot [h_{t-1}, x_t] + b_c )
  • 當前時刻的隱層輸出: h_t = o_t \circ tanh(c_t)

我們注意到, 首先三個門的激活函數是sigmoid, 這也就意味着這三個門的輸出要么接近於0 , 要么接近於1。這就使得 \frac{\delta c_t}{\delta c_{t-1}} = f_t, \frac{\delta h_t}{\delta h_{t-1}} = o_t 是非0即1的,當門為1時, 梯度能夠很好的在LSTM中傳遞,很大程度上減輕了梯度消失發生的概率, 當門為0時,說明上一時刻的信息對當前時刻沒有影響, 我們也就沒有必要傳遞梯度回去來更新參數了。所以, 這就是為什么通過門機制就能夠解決梯度的原因: 使得單元間的傳遞 \frac{\delta S_j}{\delta S_{j-1}} 為0 或 1。

 

https://blog.csdn.net/hx14301009/article/details/80401227  里提到還有CEC。

 

則梯度會隨着反向傳播層數的增加而呈指數增長,導致梯度爆炸。

如果對於所有的 有

則在經過多層的傳播后,梯度會趨向於0,導致梯度彌散(消失)。

Sepp Hochreiter 和 Jürgen Schmidhuber 在他們提出 Long Short Term Memory 的文章里講到,為了避免梯度彌散和梯度爆炸,一個 naive 的方法就是強行讓 error flow 變成一個常數:


就是RNN里自己到自己的連接。他們把這樣得到的模塊叫做CEC(constant error carrousel),很顯然由於上面那個約束條件的存在,這個CEC模塊是線性的。這就是LSTM處理梯度消失的問題的動機。

通俗地講:RNN中,每個記憶單元h_t-1都會乘上一個W和激活函數的導數,這種連乘使得記憶衰減的很快,而LSTM是通過記憶和當前輸入"相加",使得之前的記憶會繼續存在而不是受到乘法的影響而部分“消失”,因此不會衰減。但是這種naive的做法太直白了,實際上就是個線性模型,在學習效果上不夠好,因此LSTM引入了那3個門:

作者說所有“gradient based”的方法在權重更新都會遇到兩個問題:

input weight conflict 和 output weight conflict

大意就是對於神經元的權重 ,不同的數據 所帶來的更新是不同的,這樣可能會引起沖突(比如有些輸入想讓權重變小,有些想讓它變大)。網絡可能需要選擇性地“忘記”某些輸入,以及“屏蔽”某些輸出以免影響下一層的權重更新。為了解決這些問題就提出了“門”。

舉個例子:在英文短語中,主語對謂語的狀態具有影響,而如果之前同時出現過第一人稱和第三人稱,那么這兩個記憶對當前謂語就會有不同的影響,為了避免這種矛盾,我們希望網絡可以忘記一些記憶來屏蔽某些不需要的影響。

因為LSTM對記憶的操作是相加的,線性的,使得不同時序的記憶對當前的影響相同,為了讓不同時序的記憶對當前影響變得可控,LSTM引入了輸入門和輸出門,之后又有人對LSTM進行了擴展,引入了遺忘門。

總結一下:LSTM把原本RNN的單元改造成一個叫做CEC的部件,這個部件保證了誤差將以常數的形式在網絡中流動 ,並在此基礎上添加輸入門和輸出門使得模型變成非線性的,並可以調整不同時序的輸出對模型后續動作的影響。



 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM