深度學習(Deep Learning):循環神經網絡一(RNN)


原址:https://blog.csdn.net/fangqingan_java/article/details/53014085

概述

循環神經網絡(RNN-Recurrent Neural Network)是神經網絡家族中的一員,擅長於解決序列化相關問題。包括不限於序列化標注問題、NER、POS、語音識別等。RNN內容比較多,分成三個小節進行介紹,內容包括RNN基礎以及求解算法、LSTM以及變種GRU、RNN相關應用。本節主要介紹

1.RNN基礎知識介紹 
2.RNN模型優化以及存在的問題 
3.RNN模型變種

RNN知識點

RNN提出動機

RNN的提出可以有效解決以下問題:

  1. 長期依賴問題:在語言模型、語音識別中需要根據上下文進行推斷和預測,上下文的獲取可以根據馬爾科夫假設獲取固定上下文。RNN可以通過中間狀態保存上下文信息,作為輸入影響下一時序的預測。
  2. 編碼:可以將可變輸入編碼成固定長度的向量。和CNN相比,能夠保留全局最優特征。

    計算圖展開

    RNN常用以下公式獲取歷史狀態

    ht=f(ht−1,xt;θ)ht=f(ht−1,xt;θ)


    其中h為隱藏層,用於保存上下文信息,f是激活函數。 
    用圖模型可以表達為: 這里寫圖片描述

     

RNN潛在可能的展開方式如下: 
1)通過隱藏層傳遞信息 
這里寫圖片描述

1.該展開形式非常常用,主要包括三層輸入-隱藏層、隱藏層-隱藏層、隱藏層到輸入層。依賴信息通過隱藏層進行傳遞。 
2.參數U、V、W為共享參數

2)輸出節點連接到下一時序序列 
這里寫圖片描述

應用比較局限,上一時序的輸出作為下一時間點的輸入,理論上上一時間點的輸出比較固定,能夠攜帶的信息比較少。

3)只有一個輸出節點 
這里寫圖片描述

只在最后時間點t產生輸出,往往能夠將變成的輸入轉換為固定長度的向量表示。

RNN使用形式

在使用RNN時,主要形式有4中,如下圖所示。 
這里寫圖片描述

1.一對一形式(左一:Many to Many)每一個輸入都有對應的輸出。 
2.多對一形式(左二:Many to one)整個序列只有一個輸出,例如文本分類、情感分析等。 
3. 一對多形式(左三:One to Many)一個輸入產出一個時序序列,常用於seq2seq的解碼階段 
4.多對多形式(左四:Many to Many)不是每一個輸入對應一個輸出,對應到變成的輸出。

RNN數學表達以及優化

RNN前向傳播

對於離散時間的RNN問題可以描述為,輸入序列

(x1,y1),(x2,y2),(x3,y3)......(xT,yT)(x1,y1),(x2,y2),(x3,y3)......(xT,yT)


其中時間參數t表示離散序列,不一定是真實時間點。 
對於多分類問題,目標是最小化釋然函數 

min∑t=1TL(y^(xt),yt)=min−∑tlog p(yt|x1,x2...xt)min∑t=1TL(y^(xt),yt)=min−∑tlog p(yt|x1,x2...xt)

 

根據上面經典的RNN網絡結構,前向傳播過程如下: 
如上圖U、V、W分別表示輸入到隱藏層、隱藏層到輸出以及隱藏到隱藏層的連接參數。 
1. 隱藏層節點權值:at=b+Wht−1+Uxtat=b+Wht−1+Uxt 
2. 隱藏層非線性變換: ht=tanh(at)ht=tanh(at) 
3. 輸出層: ot=c+Vhtot=c+Vht 
4. softmax層: y^t=softmax(ot)y^t=softmax(ot)

RNN優化算法-BPTT

BPTT 是求解RNN問題的一種優化算法,也是基於BP算法改進得到和BP算法比較類似。為直觀上理解通過多分類問題進行簡單推導。 
1. 優化目標,對於多分類問題,BPTT優化目標轉換最小化交叉熵:

min∑tLtLt=−∑kytklogy^tkmin∑tLtLt=−∑kyktlogy^kt

這里假設有k個類 
2. 由於總的損失L為各個時序點的損失和,因此有

∂L∂Lt=1∂L∂Lt=1


3. 對於輸出層中的第i節點有

(∇otL)i=∂L∂oti=∂L∂Lt∂Lt∂oti=y^ti−1i,yt(∇otL)i=∂L∂oit=∂L∂Lt∂Lt∂oit=y^it−1i,yt

最后一步是交叉熵推導結果,步驟省略,了解softmax的都清楚。1i,yt1i,yt表示如果y^t==i則為1,否則為0 
4. 隱藏層節點梯度的計算,分為兩部分,第一部分 t=T。

(∇hTL)i=∑j(∇oTL)j∂oTj∂hTi=∑j(∇oTL)jVij(∇hTL)i=∑j(∇oTL)j∂ojT∂hiT=∑j(∇oTL)jVij

通過向量的方式表達為

(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V


5.第二部分, 中間節點 t<Tt<T,對於中間節點需要考慮t+1以及以后時間點傳播的誤差,因此計算過程如下。

(∇htL)i=∑j(∇ht+1L)j∂ht+1j∂hti+∑k(∇otL)k∂otk∂hti=隱藏層誤差反饋+輸出層誤差反饋=∑j(∇ht+1L)j∂ht+1j∂at+1j∂at+1j∂hti+∑k(∇otL)kVki=∑j(∇ht+1L)j(1−ht+1j2)Wji+∑k(∇otL)kVki=(∇ht+1L)diag((1−ht+12))Wi+(∇otL)Vi(∇htL)i=∑j(∇ht+1L)j∂hjt+1∂hit+∑k(∇otL)k∂okt∂hit=隱藏層誤差反饋+輸出層誤差反饋=∑j(∇ht+1L)j∂hjt+1∂ajt+1∂ajt+1∂hit+∑k(∇otL)kVki=∑j(∇ht+1L)j(1−hjt+12)Wji+∑k(∇otL)kVki=(∇ht+1L)diag((1−ht+12))Wi+(∇otL)Vi

通過向量表示如下:

(∇htL)=(∇ht+1L)∂ht+1∂ht+(∇otL)∂ot∂ht=(∇ht+1L)diag((1−ht+12))W+(∇otL)V(∇htL)=(∇ht+1L)∂ht+1∂ht+(∇otL)∂ot∂ht=(∇ht+1L)diag((1−ht+12))W+(∇otL)V

其中diag((1−ht+12))diag((1−ht+12))是由1−ht+1i1−hit+1的平方組成的對角矩陣。 
6.根據中間結果的梯度可以推導出其他參數的梯度,結果如下

∇cL∇bL∇VL∇WL∇UL=∑t(∇toL)∂ot∂c=∑t(∇toL)=∑t(∇thL)∂ht∂b=∑t(∇thL)diag((1−ht2))=∑t(∇toL)∂ot∂V=∑t(∇toL)htT=∑t(∇thL)∂ht∂W=∑t(∇thL)diag((1−ht2))ht−1T=∑t(∇thL)∂ht∂U=∑t(∇thL)diag((1−ht2))xtT∇cL=∑t(∇otL)∂ot∂c=∑t(∇otL)∇bL=∑t(∇htL)∂ht∂b=∑t(∇htL)diag((1−ht2))∇VL=∑t(∇otL)∂ot∂V=∑t(∇otL)htT∇WL=∑t(∇htL)∂ht∂W=∑t(∇htL)diag((1−ht2))ht−1T∇UL=∑t(∇htL)∂ht∂U=∑t(∇htL)diag((1−ht2))xtT


7. 到此完成了對所有參數梯度的推導。

 

梯度彌散和爆炸問題

RNN訓練比較困難,主要原因在於隱藏層參數W,無論在前向傳播過程還是在反向傳播過程中都會乘上多次。這樣就會導致1)前向傳播某個小於1的值乘上多次,對輸出影響變小。2)反向傳播時會導致梯度彌散問題,參數優化變得比較困難。 
這里寫圖片描述

可以通過梯度公式也可以看出梯度彌散或者爆炸問題。 
考慮到通用性,激活函數采用f(x)代替,則對隱藏層到隱藏層參數W梯度公式如下: 

∇WL=∑t(∇thL)∂ht∂W=∑t(∇thL)diag(f′(ht))ht−1∇WL=∑t(∇htL)∂ht∂W=∑t(∇htL)diag(f′(ht))ht−1

后面部分可以直接得到,下面詳細分析它的系數(∇thL)(∇htL)

 

1.考慮當t=T,即為最后一個節點時,根據上面的推導有

(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V


2.當t=T-1時,

(∇hT−1L)=(∇ThL)∂ht+1∂ht=(∇hTL)diag(f′(hT))W(∇hT−1L)=(∇hTL)∂ht+1∂ht=(∇hTL)diag(f′(hT))W

注這里只考慮隱藏層節點對W的誤差傳遞,沒有考慮輸出層。 
3. 當t=T-2時,

(∇hT−2L)=(∇T−1hL)∂hT−1∂hT−2=(∇hTL)diag(f′(hT))Wdiag(f′(hT−1))W=(∇hTL)diag(f′(hT))diag(f′(hT−1))W2(∇hT−2L)=(∇hT−1L)∂hT−1∂hT−2=(∇hTL)diag(f′(hT))Wdiag(f′(hT−1))W=(∇hTL)diag(f′(hT))diag(f′(hT−1))W2


4. 當t=k時

(∇hkL)=(∇ThL)∏j=k+1T∂hj∂hj−1=(∇hTL)∏j=kTdiag(f′(hj))W(∇hkL)=(∇hTL)∏j=k+1T∂hj∂hj−1=(∇hTL)∏j=kTdiag(f′(hj))W


5.此時diag(f′(hj))Wdiag(f′(hj))W的結果是一個對角矩陣,如果其中某個元素大於1,則該值會指數倍放大;否則會以指數倍縮小。 
6.因此可以看出當序列比較長,即模型有長期依賴問題時,就會產生梯度相關問題。一般情況下BPTT對於序列長度在100以內,不會暴露問題。 
7.需要注意的是,如果我們的訓練樣本被人工分為子序列,且長度都較小時,不會產生梯度問題。此時比較依賴於前期預處理

 

梯度問題解決方案

梯度爆炸問題方案

該問題采用截斷的方式有效避免,並且取得較好的效果。 
這里寫圖片描述

梯度彌散問題解決方案

針對該問題,有大量的解決方法,效果不一致。 
1.有效初始化+ReLU激活函數能夠得到較好效果 
2.算法上的優化,例如截斷的BPTT算法。 
3.模型上的改進,例如LSTM、GRU單元都可以有效解決長期依賴問題。 
4.在BPTT算法中加入skip connection,此時誤差可以間歇的向前傳播。 
5.加入一些Leaky Units,思路類似於skip connection

RNN模型改進

主要有兩大類思路

雙向RNN(Bi-RNN)

此時不僅可以依賴前面的上下文,還可以依賴后面的上下文。 
這里寫圖片描述

深度RNN(Deep-RNN)

有多種方式進行深度RNN的組合,左一比較常用。 
這里寫圖片描述

總結

通過該小結的總結,可以了解到 
1)RNN模型優勢以及處理問題形式。 
2)標准RNN的數學公式以及BPTT推導 
3)RNN模型訓練中的梯度問題以及如何避免


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM