三步理解--門控循環單元(GRU),TensorFlow實現


1. 什么是GRU

在循環神經⽹絡中的梯度計算⽅法中,我們發現,當時間步數較⼤或者時間步較小時,循環神經⽹絡的梯度較容易出現衰減或爆炸。雖然裁剪梯度可以應對梯度爆炸,但⽆法解決梯度衰減的問題。通常由於這個原因,循環神經⽹絡在實際中較難捕捉時間序列中時間步距離較⼤的依賴關系。

門控循環神經⽹絡(gated recurrent neural network)的提出,正是為了更好地捕捉時間序列中時間步距離較⼤的依賴關系。它通過可以學習的⻔來控制信息的流動。其中,門控循環單元(gatedrecurrent unit,GRU)是⼀種常⽤的門控循環神經⽹絡。

2. ⻔控循環單元

2.1 重置門和更新門

GRU它引⼊了重置⻔(reset gate)和更新⻔(update gate)的概念,從而修改了循環神經⽹絡中隱藏狀態的計算⽅式。

門控循環單元中的重置⻔和更新⻔的輸⼊均為當前時間步輸⼊ \(X_t\) 與上⼀時間步隱藏狀態\(H_{t-1}\),輸出由激活函數為sigmoid函數的全連接層計算得到。 如下圖所示:

具體來說,假設隱藏單元個數為 h,給定時間步 t 的小批量輸⼊ \(X_t\in_{}\mathbb{R}^{n*d}\)(樣本數為n,輸⼊個數為d)和上⼀時間步隱藏狀態 \(H_{t-1}\in_{}\mathbb{R}^{n*h}\)。重置⻔ \(H_t\in_{}\mathbb{R}^{n*h}\) 和更新⻔ \(Z_t\in_{}\mathbb{R}^{n*h}\) 的計算如下:

\[R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r) \]

\[Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z) \]

sigmoid函數可以將元素的值變換到0和1之間。因此,重置⻔ \(R_t\) 和更新⻔ \(Z_t\) 中每個元素的值域都是[0, 1]。

2.2 候選隱藏狀態

接下來,⻔控循環單元將計算候選隱藏狀態來輔助稍后的隱藏狀態計算。我們將當前時間步重置⻔的輸出與上⼀時間步隱藏狀態做按元素乘法(符號為)。如果重置⻔中元素值接近0,那么意味着重置對應隱藏狀態元素為0,即丟棄上⼀時間步的隱藏狀態。如果元素值接近1,那么表⽰保留上⼀時間步的隱藏狀態。然后,將按元素乘法的結果與當前時間步的輸⼊連結,再通過含激活函數tanh的全連接層計算出候選隱藏狀態,其所有元素的值域為[-1,1]。

具體來說,時間步 t 的候選隱藏狀態 \(\tilde{H}\in_{}\mathbb{R}^{n*h}\) 的計算為:

\[\tilde{H}_t=tanh(X_tW_{xh}+(R_t⊙H_{t-1})W_{hh}+b_h) \]

從上⾯這個公式可以看出,重置⻔控制了上⼀時間步的隱藏狀態如何流⼊當前時間步的候選隱藏狀態。而上⼀時間步的隱藏狀態可能包含了時間序列截⾄上⼀時間步的全部歷史信息。因此,重置⻔可以⽤來丟棄與預測⽆關的歷史信息。

2.3 隱藏狀態

最后,時間步t的隱藏狀態 \(H_t\in_{}\mathbb{R}^{n*h}\) 的計算使⽤當前時間步的更新⻔\(Z_t\)來對上⼀時間步的隱藏狀態 \(H_{t-1}\) 和當前時間步的候選隱藏狀態 \(\tilde{H}_t\) 做組合:

值得注意的是,更新⻔可以控制隱藏狀態應該如何被包含當前時間步信息的候選隱藏狀態所更新,如上圖所⽰。假設更新⻔在時間步 \(t^{′}到t(t^{′}<t)\) 之間⼀直近似1。那么,在時間步 \(t^{′}到t\) 間的輸⼊信息⼏乎沒有流⼊時間步 t 的隱藏狀態\(H_t\)實際上,這可以看作是較早時刻的隱藏狀態 \(H_{t^{′}-1}\) 直通過時間保存並傳遞⾄當前時間步 t。這個設計可以應對循環神經⽹絡中的梯度衰減問題,並更好地捕捉時間序列中時間步距離較⼤的依賴關系。

我們對⻔控循環單元的設計稍作總結:

  • 重置⻔有助於捕捉時間序列⾥短期的依賴關系;
  • 更新⻔有助於捕捉時間序列⾥⻓期的依賴關系。

3. 代碼實現GRU

MNIST--GRU實現

機器學習通俗易懂系列文章

3.png

4. 參考文獻

《動手學--深度學習》


作者:@mantchs

GitHub:https://github.com/NLP-LOVE/ML-NLP

歡迎大家加入討論!共同完善此項目!群號:【541954936】NLP面試學習群


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM