RNN中的梯度消失/爆炸原因
梯度消失/梯度爆炸是深度學習中老生常談的話題,這篇博客主要是對RNN中的梯度消失/梯度爆炸原因進行公式層面上的直觀理解。
首先,上圖是RNN的網絡結構圖,\((x_1, x_2, x_3, …, )\)是輸入的序列,\(X_t\)表示時間步為\(t\)時的輸入向量。假設我們總共有\(k\)個時間步,用第\(k\)個時間步的輸出\(H_k\)作為輸出(實際上每個時間步都有輸出,這里僅考慮\(H_k\)),用\(E_k\)表示損失。
其中,\(C_{t}=\tanh \left(W_{c} C_{t-1}+W_{x} X_{t}\right)\)
從上式可以看出 \(W_x\)和\(W_c\)其實是差不多的,記\(W=[W_c, W_x]\),那么求偏導可以得到:
\(\begin{aligned} \frac{\partial E_{k}}{\partial W}=& \frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}} \frac{\partial C_{k}}{\partial C_{k-1}} \ldots \frac{\partial C_{2}}{\partial C_{1}} \frac{\partial C_{1}}{\partial W}=\\ & \frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}}\left(\prod_{t=2}^{k} \frac{\partial C_{t}}{\partial C_{t-1}}\right) \frac{\partial C_{1}}{\partial W} \end{aligned}\)
其中的累乘部分為:
\(\begin{aligned} \frac{\partial C_{t}}{\partial c_{t-1}}=& \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot \frac{d}{d C_{t-1}}\left[W_{c} C_{t-1}+W_{x} X_{t}\right]=\\ & \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot W_{c} \end{aligned}\)
將該式代入上式有:
\(\frac{\partial E_{k}}{\partial W}=\frac{\partial E_{k}}{\partial H_{k}} \frac{\partial H_{k}}{\partial C_{k}}\left(\prod_{t=2}^{k} \tanh ^{\prime}\left(W_{c} C_{t-1}+W_{x} X_{t}\right) \cdot W_{c}\right) \frac{\partial c_{1}}{\partial W}\)
觀察這個式子,和上篇文章中一樣,因為鏈式法則,出現了累乘項,因為tanh的導數 <= 1,所以,當k很大的時候,上式的值是趨向於0的。(<1的數多次相乘),也就是:
\(\Pi_{t=2}^{k} \tanh ^{\prime}\left(W_{c} C_{t-1}+w_{x} X_{t}\right) \cdot W_{c} \rightarrow 0,\) so \(\frac{\partial E_{k}}{\partial W} \rightarrow 0\)
此時,權重更新公式:
\(W \leftarrow W-\alpha \frac{\partial E_{k}}{\partial W} \approx W\)
也就是說,RNN很容易出現梯度消失現象,使得參數更新緩慢,甚至是停止更新。