LSTM改善RNN梯度彌散和梯度爆炸問題


我們給定一個三個時間的RNN單元,如下:

我們假設最左端的輸入 S_0 為給定值, 且神經元中沒有激活函數(便於分析), 則前向過程如下:

S_1 = W_xX_1 + W_sS_0 + b_1 \qquad \qquad \qquad O_1 = W_oS_1 + b_2 \\ S_2 = W_xX_2 + W_sS_1 + b_1 \qquad \qquad \qquad O_2 = W_oS_2 + b_2 \\ S_3 = W_xX_3 + W_sS_2 + b_1 \qquad \qquad \qquad O_3 = W_oS_3 + b_2 \\

在 t=3 時刻, 損失函數為 L_3 = \frac{1}{2}(Y_3 - O_3)^2 ,那么如果我們要訓練RNN時, 實際上就是是對 W_x, W_s, W_o,b_1,b_2 求偏導, 並不斷調整它們以使得 L_3 盡可能達到最小(參見反向傳播算法與梯度下降算法)。

那么我們得到以下公式:

\frac{\delta L_3}{\delta W_0} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta W_0} \\ \frac{\delta L_3}{\delta W_x} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta W_x} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta W_x} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_x} \\ \frac{\delta L_3}{\delta W_s} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta W_s} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta W_s} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_s} \\

將上述偏導公式與第三節中的公式比較,我們發現, 隨着神經網絡層數的加深對 W_0 而言並沒有什么影響, 而對 W_x, W_s 會隨着時間序列的拉長而產生梯度消失和梯度爆炸問題。

根據上述分析整理一下公式可得, 對於任意時刻t對 W_x, W_s 求偏導的公式為:

\frac{\delta L_t}{\delta W_x } = \sum_{k=0}^t \frac{\delta L_t}{\delta O_t} \frac{\delta O_t}{\delta S_t}( \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} ) \frac{ \delta S_k }{\delta W_x} \\ \frac{\delta L_t}{\delta W_s } = \sum_{k=0}^t \frac{\delta L_t}{\delta O_t} \frac{\delta O_t}{\delta S_t}( \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} ) \frac{ \delta S_k }{\delta W_s}

由 以上可知,RNN 中總的梯度是不會消失的。即便梯度越傳越弱,那也只是遠距離的梯度消失,由於近距離的梯度不會消失,所有梯度之和便不會消失。RNN 所謂梯度消失的真正含義是,梯度被近距離梯度主導,導致模型難以學到遠距離的依賴關系。

參考:

https://www.cnblogs.com/bonelee/p/10475453.html

 

https://www.zhihu.com/question/34878706


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM