循環神經網絡(RNN)模型與前向反向傳播算法


    在前面我們講到了DNN,以及DNN的特例CNN的模型和前向反向傳播算法,這些算法都是前向反饋的,模型的輸出和模型本身沒有關聯關系。今天我們就討論另一類輸出和模型間有反饋的神經網絡:循環神經網絡(Recurrent Neural Networks ,以下簡稱RNN),它廣泛的用於自然語言處理中的語音識別,手寫書別以及機器翻譯等領域。

1. RNN概述

    在前面講到的DNN和CNN中,訓練樣本的輸入和輸出是比較的確定的。但是有一類問題DNN和CNN不好解決,就是訓練樣本輸入是連續的序列,且序列的長短不一,比如基於時間的序列:一段段連續的語音,一段段連續的手寫文字。這些序列比較長,且長度不一,比較難直接的拆分成一個個獨立的樣本來通過DNN/CNN進行訓練。

    而對於這類問題,RNN則比較的擅長。那么RNN是怎么做到的呢?RNN假設我們的樣本是基於序列的。比如是從序列索引1到序列索引$\tau$的。對於這其中的任意序列索引號$t$,它對應的輸入是對應的樣本序列中的$x^{(t)}$。而模型在序列索引號$t$位置的隱藏狀態$h^{(t)}$,則由$x^{(t)}$和在$t-1$位置的隱藏狀態$h^{(t-1)}$共同決定。在任意序列索引號$t$,我們也有對應的模型預測輸出$o^{(t)}$。通過預測輸出$o^{(t)}$和訓練序列真實輸出$y^{(t)}$,以及損失函數$L^{(t)}$,我們就可以用DNN類似的方法來訓練模型,接着用來預測測試序列中的一些位置的輸出。

    下面我們來看看RNN的模型。

2. RNN模型

    RNN模型有比較多的變種,這里介紹最主流的RNN模型結構如下:

    上圖中左邊是RNN模型沒有按時間展開的圖,如果按時間序列展開,則是上圖中的右邊部分。我們重點觀察右邊部分的圖。

    這幅圖描述了在序列索引號$t$附近RNN的模型。其中:

    1)$x^{(t)}$代表在序列索引號$t$時訓練樣本的輸入。同樣的,$x^{(t-1)}$和$x^{(t+1)}$代表在序列索引號$t-1$和$t+1$時訓練樣本的輸入。

    2)$h^{(t)}$代表在序列索引號$t$時模型的隱藏狀態。$h^{(t)}$由$x^{(t)}$和$h^{(t-1)}$共同決定。

    3)$o^{(t)}$代表在序列索引號$t$時模型的輸出。$o^{(t)}$只由模型當前的隱藏狀態$h^{(t)}$決定。

    4)$L^{(t)}$代表在序列索引號$t$時模型的損失函數。

    5)$y^{(t)}$代表在序列索引號$t$時訓練樣本序列的真實輸出。

    6)$U,W,V$這三個矩陣是我們的模型的線性關系參數,它在整個RNN網絡中是共享的,這點和DNN很不相同。 也正因為是共享了,它體現了RNN的模型的“循環反饋”的思想。  

3. RNN前向傳播算法

    有了上面的模型,RNN的前向傳播算法就很容易得到了。

    對於任意一個序列索引號$t$,我們隱藏狀態$h^{(t)}$由$x^{(t)}$和$h^{(t-1)}$得到:$$h^{(t)} = \sigma(z^{(t)}) = \sigma(Ux^{(t)} + Wh^{(t-1)} +b )$$

    其中$\sigma$為RNN的激活函數,一般為$tanh$, $b$為線性關系的偏倚。

    序列索引號$t$時模型的輸出$o^{(t)}$的表達式比較簡單:$$o^{(t)} = Vh^{(t)} +c $$

    在最終在序列索引號$t$時我們的預測輸出為:$$\hat{y}^{(t)} = \sigma(o^{(t)})$$

    通常由於RNN是識別類的分類模型,所以上面這個激活函數一般是softmax。

    通過損失函數$L^{(t)}$,比如對數似然損失函數,我們可以量化模型在當前位置的損失,即$\hat{y}^{(t)}$和$y^{(t)}$的差距。

4. RNN反向傳播算法推導

    有了RNN前向傳播算法的基礎,就容易推導出RNN反向傳播算法的流程了。RNN反向傳播算法的思路和DNN是一樣的,即通過梯度下降法一輪輪的迭代,得到合適的RNN模型參數$U,W,V,b,c$。由於我們是基於時間反向傳播,所以RNN的反向傳播有時也叫做BPTT(back-propagation through time)。當然這里的BPTT和DNN也有很大的不同點,即這里所有的$U,W,V,b,c$在序列的各個位置是共享的,反向傳播時我們更新的是相同的參數。

    為了簡化描述,這里的損失函數我們為交叉熵損失函數,輸出的激活函數為softmax函數,隱藏層的激活函數為tanh函數。

    對於RNN,由於我們在序列的每個位置都有損失函數,因此最終的損失$L$為:$$L = \sum\limits_{t=1}^{\tau}L^{(t)}$$

    其中$V,c,$的梯度計算是比較簡單的:$$\frac{\partial L}{\partial c} = \sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial c}  = \sum\limits_{t=1}^{\tau}\hat{y}^{(t)} - y^{(t)}$$$$\frac{\partial L}{\partial V} =\sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial V}  = \sum\limits_{t=1}^{\tau}(\hat{y}^{(t)} - y^{(t)}) (h^{(t)})^T$$

    但是$W,U,b$的梯度計算就比較的復雜了。從RNN的模型可以看出,在反向傳播時,在在某一序列位置t的梯度損失由當前位置的輸出對應的梯度損失和序列索引位置$t+1$時的梯度損失兩部分共同決定。對於$W$在某一序列位置t的梯度損失需要反向傳播一步步的計算。我們定義序列索引$t$位置的隱藏狀態的梯度為:$$\delta^{(t)} = \frac{\partial L}{\partial h^{(t)}}$$

    這樣我們可以像DNN一樣從$\delta^{(t+1)} $遞推$\delta^{(t)}$ 。$$\delta^{(t)} =(\frac{\partial o^{(t)}}{\partial h^{(t)}} )^T\frac{\partial L}{\partial o^{(t)}} + (\frac{\partial h^{(t+1)}}{\partial h^{(t)}})^T\frac{\partial L}{\partial h^{(t+1)}} = V^T(\hat{y}^{(t)} - y^{(t)}) + W^Tdiag(1-(h^{(t+1)})^2)\delta^{(t+1)}$$

    對於$\delta^{(\tau)} $,由於它的后面沒有其他的序列索引了,因此有:$$\delta^{(\tau)} =( \frac{\partial o^{(\tau)}}{\partial h^{(\tau)}})^T\frac{\partial L}{\partial o^{(\tau)}} = V^T(\hat{y}^{(\tau)} - y^{(\tau)})$$

    有了$\delta^{(t)} $,計算$W,U,b$就容易了,這里給出$W,U,b$的梯度計算表達式:$$\frac{\partial L}{\partial W} = \sum\limits_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(h^{(t-1)})^T$$$$\frac{\partial L}{\partial b}= \sum\limits_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}$$$$\frac{\partial L}{\partial U} =\sum\limits_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(x^{(t)})^T$$

    除了梯度表達式不同,RNN的反向傳播算法和DNN區別不大,因此這里就不再重復總結了。

5. RNN小結

    上面總結了通用的RNN模型和前向反向傳播算法。當然,有些RNN模型會有些不同,自然前向反向傳播的公式會有些不一樣,但是原理基本類似。

    RNN雖然理論上可以很漂亮的解決序列數據的訓練,但是它也像DNN一樣有梯度消失時的問題,當序列很長的時候問題尤其嚴重。因此,上面的RNN模型一般不能直接用於應用領域。在語音識別,手寫書別以及機器翻譯等NLP領域實際應用比較廣泛的是基於RNN模型的一個特例LSTM,下一篇我們就來討論LSTM模型。

 

(歡迎轉載,轉載請注明出處。歡迎溝通交流:liujianping-ok@163.com) 

參考資料:

1) Neural Networks and Deep Learning by By Michael Nielsen

2) Deep Learning, book by Ian Goodfellow, Yoshua Bengio, and Aaron Courville

3) UFLDL Tutorial

4)CS231n Convolutional Neural Networks for Visual Recognition, Stanford


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM