原址:https://blog.csdn.net/fangqingan_java/article/details/53014085
概述
循環神經網絡(RNN-Recurrent Neural Network)是神經網絡家族中的一員,擅長於解決序列化相關問題。包括不限於序列化標注問題、NER、POS、語音識別等。RNN內容比較多,分成三個小節進行介紹,內容包括RNN基礎以及求解算法、LSTM以及變種GRU、RNN相關應用。本節主要介紹
1.RNN基礎知識介紹
2.RNN模型優化以及存在的問題
3.RNN模型變種
RNN知識點
RNN提出動機
RNN的提出可以有效解決以下問題:
- 長期依賴問題:在語言模型、語音識別中需要根據上下文進行推斷和預測,上下文的獲取可以根據馬爾科夫假設獲取固定上下文。RNN可以通過中間狀態保存上下文信息,作為輸入影響下一時序的預測。
-
編碼:可以將可變輸入編碼成固定長度的向量。和CNN相比,能夠保留全局最優特征。
計算圖展開
RNN常用以下公式獲取歷史狀態
ht=f(ht−1,xt;θ)ht=f(ht−1,xt;θ)
其中h為隱藏層,用於保存上下文信息,f是激活函數。
用圖模型可以表達為:
RNN潛在可能的展開方式如下:
1)通過隱藏層傳遞信息
1.該展開形式非常常用,主要包括三層輸入-隱藏層、隱藏層-隱藏層、隱藏層到輸入層。依賴信息通過隱藏層進行傳遞。
2.參數U、V、W為共享參數
2)輸出節點連接到下一時序序列
應用比較局限,上一時序的輸出作為下一時間點的輸入,理論上上一時間點的輸出比較固定,能夠攜帶的信息比較少。
3)只有一個輸出節點
只在最后時間點t產生輸出,往往能夠將變成的輸入轉換為固定長度的向量表示。
RNN使用形式
在使用RNN時,主要形式有4中,如下圖所示。
1.一對一形式(左一:Many to Many)每一個輸入都有對應的輸出。
2.多對一形式(左二:Many to one)整個序列只有一個輸出,例如文本分類、情感分析等。
3. 一對多形式(左三:One to Many)一個輸入產出一個時序序列,常用於seq2seq的解碼階段
4.多對多形式(左四:Many to Many)不是每一個輸入對應一個輸出,對應到變成的輸出。
RNN數學表達以及優化
RNN前向傳播
對於離散時間的RNN問題可以描述為,輸入序列
(x1,y1),(x2,y2),(x3,y3)......(xT,yT)(x1,y1),(x2,y2),(x3,y3)......(xT,yT)
其中時間參數t表示離散序列,不一定是真實時間點。
對於多分類問題,目標是最小化釋然函數
min∑t=1TL(y^(xt),yt)=min−∑tlog p(yt|x1,x2...xt)min∑t=1TL(y^(xt),yt)=min−∑tlog p(yt|x1,x2...xt)
根據上面經典的RNN網絡結構,前向傳播過程如下:
如上圖U、V、W分別表示輸入到隱藏層、隱藏層到輸出以及隱藏到隱藏層的連接參數。
1. 隱藏層節點權值:at=b+Wht−1+Uxtat=b+Wht−1+Uxt
2. 隱藏層非線性變換: ht=tanh(at)ht=tanh(at)
3. 輸出層: ot=c+Vhtot=c+Vht
4. softmax層: y^t=softmax(ot)y^t=softmax(ot)
RNN優化算法-BPTT
BPTT 是求解RNN問題的一種優化算法,也是基於BP算法改進得到和BP算法比較類似。為直觀上理解通過多分類問題進行簡單推導。
1. 優化目標,對於多分類問題,BPTT優化目標轉換最小化交叉熵:
min∑tLtLt=−∑kytklogy^tkmin∑tLtLt=−∑kyktlogy^kt
這里假設有k個類
2. 由於總的損失L為各個時序點的損失和,因此有
∂L∂Lt=1∂L∂Lt=1
3. 對於輸出層中的第i節點有
(∇otL)i=∂L∂oti=∂L∂Lt∂Lt∂oti=y^ti−1i,yt(∇otL)i=∂L∂oit=∂L∂Lt∂Lt∂oit=y^it−1i,yt
最后一步是交叉熵推導結果,步驟省略,了解softmax的都清楚。1i,yt1i,yt表示如果y^t==i則為1,否則為0
4. 隱藏層節點梯度的計算,分為兩部分,第一部分 t=T。
(∇hTL)i=∑j(∇oTL)j∂oTj∂hTi=∑j(∇oTL)jVij(∇hTL)i=∑j(∇oTL)j∂ojT∂hiT=∑j(∇oTL)jVij
通過向量的方式表達為
(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V
5.第二部分, 中間節點 t<Tt<T,對於中間節點需要考慮t+1以及以后時間點傳播的誤差,因此計算過程如下。
(∇htL)i=∑j(∇ht+1L)j∂ht+1j∂hti+∑k(∇otL)k∂otk∂hti=隱藏層誤差反饋+輸出層誤差反饋=∑j(∇ht+1L)j∂ht+1j∂at+1j∂at+1j∂hti+∑k(∇otL)kVki=∑j(∇ht+1L)j(1−ht+1j2)Wji+∑k(∇otL)kVki=(∇ht+1L)diag((1−ht+12))Wi+(∇otL)Vi(∇htL)i=∑j(∇ht+1L)j∂hjt+1∂hit+∑k(∇otL)k∂okt∂hit=隱藏層誤差反饋+輸出層誤差反饋=∑j(∇ht+1L)j∂hjt+1∂ajt+1∂ajt+1∂hit+∑k(∇otL)kVki=∑j(∇ht+1L)j(1−hjt+12)Wji+∑k(∇otL)kVki=(∇ht+1L)diag((1−ht+12))Wi+(∇otL)Vi
通過向量表示如下:
(∇htL)=(∇ht+1L)∂ht+1∂ht+(∇otL)∂ot∂ht=(∇ht+1L)diag((1−ht+12))W+(∇otL)V(∇htL)=(∇ht+1L)∂ht+1∂ht+(∇otL)∂ot∂ht=(∇ht+1L)diag((1−ht+12))W+(∇otL)V
其中diag((1−ht+12))diag((1−ht+12))是由1−ht+1i1−hit+1的平方組成的對角矩陣。
6.根據中間結果的梯度可以推導出其他參數的梯度,結果如下
∇cL∇bL∇VL∇WL∇UL=∑t(∇toL)∂ot∂c=∑t(∇toL)=∑t(∇thL)∂ht∂b=∑t(∇thL)diag((1−ht2))=∑t(∇toL)∂ot∂V=∑t(∇toL)htT=∑t(∇thL)∂ht∂W=∑t(∇thL)diag((1−ht2))ht−1T=∑t(∇thL)∂ht∂U=∑t(∇thL)diag((1−ht2))xtT∇cL=∑t(∇otL)∂ot∂c=∑t(∇otL)∇bL=∑t(∇htL)∂ht∂b=∑t(∇htL)diag((1−ht2))∇VL=∑t(∇otL)∂ot∂V=∑t(∇otL)htT∇WL=∑t(∇htL)∂ht∂W=∑t(∇htL)diag((1−ht2))ht−1T∇UL=∑t(∇htL)∂ht∂U=∑t(∇htL)diag((1−ht2))xtT
7. 到此完成了對所有參數梯度的推導。
梯度彌散和爆炸問題
RNN訓練比較困難,主要原因在於隱藏層參數W,無論在前向傳播過程還是在反向傳播過程中都會乘上多次。這樣就會導致1)前向傳播某個小於1的值乘上多次,對輸出影響變小。2)反向傳播時會導致梯度彌散問題,參數優化變得比較困難。
可以通過梯度公式也可以看出梯度彌散或者爆炸問題。
考慮到通用性,激活函數采用f(x)代替,則對隱藏層到隱藏層參數W梯度公式如下:
∇WL=∑t(∇thL)∂ht∂W=∑t(∇thL)diag(f′(ht))ht−1∇WL=∑t(∇htL)∂ht∂W=∑t(∇htL)diag(f′(ht))ht−1
后面部分可以直接得到,下面詳細分析它的系數(∇thL)(∇htL)
1.考慮當t=T,即為最后一個節點時,根據上面的推導有
(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V(∇hTL)=(∇oTL)∂oT∂hT=(∇oTL)V
2.當t=T-1時,(∇hT−1L)=(∇ThL)∂ht+1∂ht=(∇hTL)diag(f′(hT))W(∇hT−1L)=(∇hTL)∂ht+1∂ht=(∇hTL)diag(f′(hT))W
注這里只考慮隱藏層節點對W的誤差傳遞,沒有考慮輸出層。
3. 當t=T-2時,(∇hT−2L)=(∇T−1hL)∂hT−1∂hT−2=(∇hTL)diag(f′(hT))Wdiag(f′(hT−1))W=(∇hTL)diag(f′(hT))diag(f′(hT−1))W2(∇hT−2L)=(∇hT−1L)∂hT−1∂hT−2=(∇hTL)diag(f′(hT))Wdiag(f′(hT−1))W=(∇hTL)diag(f′(hT))diag(f′(hT−1))W2
4. 當t=k時(∇hkL)=(∇ThL)∏j=k+1T∂hj∂hj−1=(∇hTL)∏j=kTdiag(f′(hj))W(∇hkL)=(∇hTL)∏j=k+1T∂hj∂hj−1=(∇hTL)∏j=kTdiag(f′(hj))W
5.此時diag(f′(hj))Wdiag(f′(hj))W的結果是一個對角矩陣,如果其中某個元素大於1,則該值會指數倍放大;否則會以指數倍縮小。
6.因此可以看出當序列比較長,即模型有長期依賴問題時,就會產生梯度相關問題。一般情況下BPTT對於序列長度在100以內,不會暴露問題。
7.需要注意的是,如果我們的訓練樣本被人工分為子序列,且長度都較小時,不會產生梯度問題。此時比較依賴於前期預處理
梯度問題解決方案
梯度爆炸問題方案
該問題采用截斷的方式有效避免,並且取得較好的效果。
梯度彌散問題解決方案
針對該問題,有大量的解決方法,效果不一致。
1.有效初始化+ReLU激活函數能夠得到較好效果
2.算法上的優化,例如截斷的BPTT算法。
3.模型上的改進,例如LSTM、GRU單元都可以有效解決長期依賴問題。
4.在BPTT算法中加入skip connection,此時誤差可以間歇的向前傳播。
5.加入一些Leaky Units,思路類似於skip connection
RNN模型改進
主要有兩大類思路
雙向RNN(Bi-RNN)
此時不僅可以依賴前面的上下文,還可以依賴后面的上下文。
深度RNN(Deep-RNN)
有多種方式進行深度RNN的組合,左一比較常用。
總結
通過該小結的總結,可以了解到
1)RNN模型優勢以及處理問題形式。
2)標准RNN的數學公式以及BPTT推導
3)RNN模型訓練中的梯度問題以及如何避免