Recurrent Neural Network系列3--理解RNN的BPTT算法和梯度消失


作者:zhbzz2007 出處:http://www.cnblogs.com/zhbzz2007 歡迎轉載,也請保留這段聲明。謝謝!

這是RNN教程的第三部分。

在前面的教程中,我們從頭實現了一個循環神經網絡,但是並沒有涉及隨時間反向傳播(BPTT)算法如何計算梯度的細節。在這部分,我們將會簡要介紹BPTT並解釋它和傳統的反向傳播有何區別。我們也會嘗試着理解梯度消失問題,這也是LSTM和GRU(目前NLP及其它領域中最為流行和有用的模型)得以發展的原因。梯度消失問題最早是由 Sepp Hochreiter 在1991年發現,最近由於深度框架的廣泛應用再次獲得很多關注。

為了能夠完全理解這部分,我建議你熟悉偏微分和基本的反向傳播工作原理。如果你不熟悉這些內容,你需要看這些教程 CS231n Convolutional Neural Networks for Visual RecognitionCalculus on Computational Graphs: BackpropagationHow the backpropagation algorithm works ,這些教程的難度依次增加 。

1 BPTT

讓我們快速回憶一下循環神經網絡中的一些基本公式。定義中略微有些變化,我們將 \(o\) 修改為 \(\hat{y}\) 。這是為了與一些參考文獻保持一致。

\(s_{t} = tanh(U x_{t} + W s_{t-1})\)

\(\hat{y_{t}} = softmax(V s_{t})\)

我們定義損失或者誤差為互熵損失,如下所示,

\(E_{t}(y_{t},\hat{y_{t}}) = -y_{t}log(\hat{y_{t}})\)

\(E_{t}(y,\hat{y}) = \sum_{t}E_{t}(y_{t},\hat{y_{t}})=-\sum_{t}y_{t}log(\hat{y_{t}})\)

在這里, \(y_{t}\) 是時刻 t 上正確的詞, \(\hat{y_{t}}\) 是預測出來的詞。我們通常將一整個序列(一個句子)作為一個訓練實例,所以總的誤差就是各個時刻(詞)的誤差之和。

請牢記,我們的目標是計算誤差關於參數U、V和W的梯度,然后使用梯度下降法學習出好的參數。正如我們將誤差相加,我們也將一個訓練實例在每時刻的梯度相加: \(\frac{\partial E}{\partial W} = \sum_{t}\frac{\partial E_{t}}{\partial W}\)

為了計算這些梯度,我們需要使用微分的鏈式法則。當從誤差開始向后時,這就是 反向傳播 。在本文后續的部分,我們將會以 \(E_{3}\) 為例,僅僅是為了使用具體的數字。

\(\frac{\partial E_{3}}{\partial V} = \frac{\partial E_{3}}{\partial \hat{y_{3}}} \frac{\partial \hat{y_{3}}}{\partial V} =\frac{\partial E_{3}}{\partial \hat{y_{3}}} \frac{\partial \hat{y_{3}}}{\partial z_{3}} \frac{\partial z_{3}}{\partial V}=(\hat{y_{3}} - y_{3}) \otimes s_{3}\)

在上述定義中,我們定義 \(z_{3} = V s_{3}\)\(\otimes\) 是兩個向量的外積。如果你暫時跟不上,不要擔心,我忽略了其中幾步,你也可以嘗試着自己計算這些梯度。我想要強調的是 \(\frac{\partial E_{3}}{\partial V}\) 僅僅依賴當前時刻的值,如 \(\hat{y_{3}}\)\(y_{3}\)\(s_{3}\) 。如果你已經有這些值,計算變量V的梯度就是一個簡單的矩陣相乘。

計算 \(\frac{\partial E_{3}}{\partial W}\) 卻有所不同,對於U也是。為了了解原因,我們寫出鏈式法則,正如上面所示,

\(\frac{\partial E_{3}}{\partial W}=\frac{\partial E_{3}}{\partial \hat{y_{3}}} \frac{\partial \hat{y_{3}}}{\partial s_{3}} \frac{\partial s_{3}}{\partial W}\)

其中, \(s_{3} = tanh(U x_{t} + W s_{2})\) (應該為 \(s_{3} = tanh(U x_{3} + W s_{2})\) )依賴於 \(s_{2}\) ,而 \(s_{2}\) 依賴於 W和 \(s_{1}\) 。所以如果我們對 W 求導數,我們不能簡單的將 \(s_{2}\) 視為一個常量。我們需要再次應用鏈式法則,我們真正想要的如下所示:

\(\frac{\partial E_{3}}{\partial W}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial \hat{y_{3}}} \frac{\partial \hat{y_{3}}}{\partial s_{3}} \frac{\partial s_{3}}{\partial s_{k}} \frac{\partial s_{k}}{\partial W}\)

我們將每時刻對梯度的貢獻相加。也就是說,由於 W 在每時刻都用在我們所關心的輸出上,我們需要從時刻 t = 3 通過網絡的所有路徑到時刻 t = 0 來反向傳播梯度:

請留意,這與我們在深度前饋神經網絡中使用的標准反向傳播算法完全相同。主要的差異就是我們將每時刻 W 的梯度相加。在傳統的神經網絡中,我們在層之間並沒有共享參數,所以我們不需要相加。但是我認為,BPTT就是標准反向傳播算法在展開的循環神經網絡上一個花哨的名稱。正如在反向傳播算法中,你可以定義一個反向傳播的 delta 向量,例如 \(\delta_{2}^{(3)} = \frac{\partial E_{3}}{\partial z_{2}} = \frac{\partial E_{3}}{\partial s_{3}} \frac{\partial s_{3}}{\partial s_{2}} \frac{\partial s_{2}}{\partial z_{2}}\) ,其中 \(z_{2} = U x_{2} + W s_{1}\) , 然后應用相同的方程。

一個朴素的BPTT實現,代碼如下,

def bptt(self, x, y):
    T = len(y)
    # Perform forward propagation
    o, s = self.forward_propagation(x)
    # We accumulate the gradients in these variables
    dLdU = np.zeros(self.U.shape)
    dLdV = np.zeros(self.V.shape)
    dLdW = np.zeros(self.W.shape)
    delta_o = o
    delta_o[np.arange(len(y)), y] -= 1.
    # For each output backwards...
    for t in np.arange(T)[::-1]:
        dLdV += np.outer(delta_o[t], s[t].T)
        # Initial delta calculation: dL/dz
        delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))
        # Backpropagation through time (for at most self.bptt_truncate steps)
        for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:
            # print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)
            # Add to gradients at each previous step
            dLdW += np.outer(delta_t, s[bptt_step-1])              
            dLdU[:,x[bptt_step]] += delta_t
            # Update delta for next step dL/dz at t-1
            delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)
    return [dLdU, dLdV, dLdW]

這應該會給你一個印象:為什么標准的循環神經網絡很難訓練?序列(句子)可以很長,可能20個詞或者更多,因此你需要反向傳播很多層。實際上,許多人會在反向傳播數步之后進行截斷。

2 梯度消失

在前面的博文 Recurrent Neural Network系列1--RNN(循環神經網絡)概述 中,我已經提到循環神經網絡很難學習到長期的依賴 -- 在相隔數步的詞之間的影響。這就會導致一些問題,因為英文句子通常被一些不是很近的詞所決定,例如:“The man who wore a wig on his head went inside” 。這個句子是關於一個人走進屋里,不是關於假發的。對於普通的循環神經網絡,不太可能捕獲這些信息。為了理解為什么,讓我們仔細分析一下上面推導出來的梯度:

\(\frac{\partial E_{3}}{\partial W}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial \hat{y_{3}}} \frac{\partial \hat{y_{3}}}{\partial s_{3}} \frac{\partial s_{3}}{\partial s_{k}} \frac{\partial s_{k}}{\partial W}\)

請注意, \(\frac{\partial s_{3}}{\partial s_{k}}\) 本身就是一個鏈式法則。例如, \(\frac{\partial s_{3}}{\partial s_{1}} = \frac{\partial s_{3}}{\partial s_{2}} \frac{\partial s_{2}}{\partial s_{1}}\) 。也要注意,我們是在一個向量上對向量函數求導,結果會是一個矩陣(稱之為 雅克比矩陣 ),所有的元素都是對應的導數。我可以將上述的梯度重寫為:

\(\frac{\partial E_{3}}{\partial W}=\sum_{k=0}^{3}\frac{\partial E_{3}}{\partial \hat{y_{3}}} \frac{\partial \hat{y_{3}}}{\partial s_{3}} (\prod_{j = k+1}^{3} \frac{\partial s_{j}}{\partial s_{j-1}}) \frac{\partial s_{k}}{\partial W}\)

上述雅克比矩陣中的2范數(你可以認為是絕對值)上限是1(具體參考這篇 On the difficulty of training recurrent neural networks)。tanh(或者sigmoid)激活函數將所有的值映射到-1到1這個區間,導數的范圍在0到1這個區間(sigmoid是0到 \(\frac{1}{4}\) 這個區間),如下圖所示:

你可以看到tanh和sigmoid函數在兩端導數均為0。它們逐漸成為一條直線,當這個現象發生時,我們就說相應的神經元已經飽和了。它們的梯度為0,驅動前一層的其它梯度也趨向於0。因此,矩陣中有小值,並且經過矩陣相乘(t - k次),梯度值快速的以指數形式收縮,最終在幾個時刻之后完全消失。較遠的時刻貢獻的梯度變為0,這些時刻的狀態不會對你的學習有所貢獻:你最終以無法學習到長期依賴而結束。梯度消失不僅僅出現在循環神經網絡中。它們也出現深度前饋神經網絡中。它僅僅是循環神經網絡趨向於很深(在我們這個例子中,深度與句子長度一樣),這將會導致很多問題。

依賴於我們的激活函數和網絡參數,如果雅克比矩陣的值非常大,我們沒有出現梯度消失,但是卻可能出現梯度爆炸。這就是梯度爆炸問題。梯度消失問題比梯度爆炸問題受到更多的關注,主要有兩個原因:1)梯度爆炸很明顯,你的梯度將會變成Nan(不是一個數字),你的程序將會掛掉;2)在預定義閾值處將梯度截斷(具體參考這篇 On the difficulty of training recurrent neural networks)是一種簡單有效的方法去解決梯度爆炸問題。梯度消失問題更加復雜是因為它不明顯,如論是當它們發生或者如何處理它們時。

幸運的是,目前已經有了一些緩解梯度消失問題的方法。對矩陣 W 合理的初始化可以減少梯度消失的影響。也可以加入正則化項。一個更好的方案是使用 ReLU而不是tanh或者sigmoid激活函數。ReLU函數的導數是個常量,要么是0,要么是1,所以它不太可能出現梯度消失。更加流行的方法是使用長短時記憶(LSTM)或者門控循環單元(GRU)架構。LSTM是在 1997年提出,在NLP領域可能是目前最為流行的模型。GRU是在2014年提出,是LSTM的簡化版。這些循環神經網絡的設計都是為了處理梯度消失和有效學習長期依賴。我們將會在后面的博文中介紹。

3 Reference

wiki-Backpropagation through time

BPTT算法推導(需要注意此文中W和U與本文的W和U是相反的)

A Beginner’s Guide to Recurrent Networks and LSTMs

Backpropagation Through Time (BPTT)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM