一、梯度消失、梯度爆炸產生的原因
說白了,對於1.1 1.2,其實就是矩陣的高次冪導致的。在多層神經網絡中,影響因素主要是權值和激活函數的偏導數。
1.1 前饋網絡
假設存在一個網絡結構如圖:
其表達式為:
若要對於w1求梯度,根據鏈式求導法則,得到的解為:
通常,若使用的激活函數為sigmoid函數,其導數:
這樣可以看到,如果我們使用標准化初始w,那么各個層次的相乘都是0-1之間的小數,而激活函數f的導數也是0-1之間的數,其連乘后,結果會變的很小,導致梯度消失。若我們初始化的w是很大的數,w大到乘以激活函數的導數都大於1,那么連乘后,可能會導致求導的結果很大,形成梯度爆炸。
當然,若對於b求偏導的話,其實也是一個道理:
推出:
1.2 RNN
對於RNN的梯度下降方法,是一種基於時間的反向求導算法(BPTT),RNN的表達式:
通常我們會將一個完整的句子序列視作一個訓練樣本,因此總誤差即為各時間步(單詞)的誤差之和。
而RNN還存在一個權值共享的問題,即這幾個w都是一個,假設,存在一個反復與w相乘的路徑,t步后,得到向量:
若特征值大於1,則會出現梯度爆炸,若特征值小於1,則會出現梯度消失。因此在一定程度上,RNN對比BP更容易出現梯度問題。主要是因為RNN處理時間步長一旦長了,W求導的路徑也變的很長,即使RNN深度不大,也會比較深的BP神經網絡的鏈式求導的過程長很大;另外,對於共享權值w,不同的wi相乘也在一定程度上可以避免梯度問題。
1.3 懸崖和梯度爆炸
對於目標函數,通常存在梯度變化很大的一個“懸崖”,在此處求梯度,很容易導致求解不穩定的梯度爆炸現象。
三、梯度消失和梯度爆炸哪種經常出現
事實上,梯度消失更容易出現,因為對於激活函數的求導:
可以看到,當w越大,其wx+b很可能變的很大,而根據上面sigmoid函數導數的圖像可以看到,wx+b越大,導數的值也會變的很小。因此,若要出現梯度爆炸,其w既要大還要保證激活函數的導數不要太小。
二、如何解決梯度消失、梯度爆炸
1、對於RNN,可以通過梯度截斷,避免梯度爆炸
2、可以通過添加正則項,避免梯度爆炸
3、使用LSTM等自循環和門控制機制,避免梯度消失,參考:https://www.cnblogs.com/pinking/p/9362966.html
4、優化激活函數,譬如將sigmold改為relu,避免梯度消失