炙手可熱的LSTM
引言
上一講說到RNN. RNN可說是目前處理時間序列的大殺器,相比於傳統的時間序列算法,使用起來更方便,不需要太多的前提假設,也不需太多的參數調節,更重要的是有學習能力,因此是一種'智能'算法.前面也說到, 不只時間序列,在很多領域,特別是涉及序列數據的,RNN的表現總是那么的'搶眼'.不過,在這搶眼的過程中, 沖鋒最前面的可不是簡單的RNN(或者說最原始的RNN), 真正把RNN這個牌子''做大做強''的是: 長短時記憶循環神經網絡(Long-Short Term Memory RNN, LSTM).
RNN的問題
傳統的(或稱原始的)RNN理論上是可以記憶任意長度的時間序列,比如你把一本<紅樓夢>給她, 理論上她也是可以記憶的. 但, 理論和實際是有差距滴~.
在應用過程中,發現RNN對長時記憶的能力比較弱, 也就是RNN的記性不太好,對於長時間的東西她就有點記不清了,而幾乎只會關注最近一段時間的信息. 也就是說, 當你給她<紅樓夢>之后,夢想着她如何給你講講''大觀園''的趣事, 她卻不知所雲地來了句' 說到辛酸處,荒唐愈可悲.由來同一夢,休笑世人痴.' —— 前面的完全忘光了! 為什么?辛辛苦苦地訓練她,可她竟這樣地不念舊情?! 請不要怪她, 她的''內在邏輯'',上她無法控制的.這個所謂的'內在邏輯'就是指數函數(Exponential function).
恐怖的指數函數
指數函數大家應該都記得:
\[y = a^x \qquad x \in R \]
這就是指數函數,是不是其貌不揚? 是不是純天然無公害? 你錯了, 它展現的實力可是爆炸性的(explosive). 大家聽過'棋盤放米'的故事吧: 第一格放1粒米(注意是一粒可不是一斤哦~),第二格放 2 粒米,第三個格放 4 粒米,…, 最后國王付不起了... 對於8x8的棋盤(國際象棋), 一共要放 (\(2^{64} - 1\)) 粒米, 即18,446,744,073,709,551,615粒. 一千克米的粒數大約在40000-60000之間, 就算最多的60000 粒, 18,446,744,073,709,551,615 / 60000 約 等於 307445734561826千克. 2016年,大宗糧油全球總產量27.84億噸,取整算28億噸, \(307445734561826 \div (28*10^{11}) \approx 110\) 年. 哪個國王支付得起?
不夠直觀? 好,再舉個例子: 拿一張A4紙,對折,對折,再對折,…嘗試一下,你能折幾次? 世界記錄只有13次! what?! 這么少?! 對的,這就是指數的威力.
\(10^{100}\), 10 的100次冪叫做Googol(哈哈,被你發現了, Google名字的由來...) ,這個Googol 可以說是宇宙極限了. 想象一下, 宇宙萬物都是由基本粒子組成,你所知道的最小粒子是什么? 電子 還是誇克? 科學家估算全宇宙的基本粒子數(都算上) 最多不超過\(10^{80}\)個!
最后一個例子 \(0.9^{20}\approx 0.12\).
通過以上,就好理解為什么RNN健忘了.回憶下上講的公式.
\[\begin{array} \\ s_t &=& f(s_{t-1},x_t)\\ &=& f(W s_{t-1}+Ux_{t})\\ &= &f(Wf(s_{t-2},x_{t-1})+Ux_t)\\ &=& ...\\ &= &f(W(f(W(...(f(Ws_0+Ux1))))))\\ \end{array} \]
看! 有多個W(矩陣)相乘, 般權重范圍在(-1,1)之間,來個指數冪,一下子就沒了...也就是之前的信息不會對當前或未來的信息產生影響了. 也就是說RNN失去了記憶的能力了.
專業一點的說法叫做梯度消失(gradient vanishing); 如果權重大於1就會產生梯度爆炸(gradient explosion)(比較少見).
梯度消失*
上面是較通俗的說法, 其實從這個問題的名字就可看出,其切入點是梯度(gradient).
回顧上階公式:
\[\nabla_W L = \sum_{t}^{T} \boldsymbol{\delta_t\odot s_{t-1}} \]
其中
\[\delta_t = \frac{\partial L}{\partial z_t} = \delta_T \Pi_{i = t}^{T-1} W\odot f'(z_i) = \delta_T * W*f'(z_t)...*W*f'(z_{T-1}) \]
用模的上界表示:
\[||\delta_t||\le ||\delta_T||. (||W||.||\bar{f}||)^{T-t} \]
其中 \(\bar{f}\)表示的是\(f'(z_t)\)的模上界. 這樣就了然了, 指數形式的出現表明前面(T-t 比較大時)的信息對於梯度的更新沒有貢獻,也就是無論之前信息怎樣,最終的權重更新(學習)都不會受到影響. 因為前面信息的梯度由於指數邏輯的存在,使梯度趨近於0 — 消失了.也就是說,RNN忘記了.比如 某個\(w_{i,j}\)的值在t時刻是的梯度0.3(實際中w的一般量級,甚至更小), t-1時 大約為0.3* 0.3 = 0.09了,t-2時約為0.09* 0.3 = 0.027,t-3; 0.027* 0.3 = 0.0081,t-4: 0.0081 *0.3 = 0.00243,…不用寫了,在反向傳播4次就已經很小了,而RNN時間深度再深一些,比如15,比如20,比如30,…, 可以想見,之前的信息RNN直接忽略了.
不只在RNN中,其實梯度消失及梯度爆炸在深度學習領域一直是一個比較頭疼的問題,這也是深度網絡難以訓練的主要原因.只不過在深度網絡中,是因為層數的增多導致產生類似指數形式的連續乘積.
解決方案
出現梯度消失(主要)與爆炸問題后,有很多解決方法提出來,比如設計更好的初始化權重,限制權重范圍等等. 這種''通用''的方法的作用有限. 在RNN中有人提出設計隱藏單元用來儲存信息,稱為儲層計算(Reservoir Computing),比如回聲狀態網絡(Echo State Network, ESN).也有人提出在不同的時間粒度上處理數據,不同的時間處理單元稱為滲透單元(Leaky Unit).但目前效果最好的,通用性最高的還是門限(Gated) RNN.其中最火的就是LSTM(Long-Short Term Memory)及GRU(Gated Recurrent Unit).
LSTM
設計初衷
首先,拋開恐怖的指數函數不談,咱們先想象一個場景:假設你很喜歡古龍的小說,他的小說你都看了好多遍.現在給一篇他的小說,比如<七種武器>里的<霸王槍>, 篇幅不長,故事也不太復雜,讓你閱讀. 幾個小時后,或者大方點,第二天,我來找你,讓你一字不落地背出第一章<落日照大旗>. 你一定會問我我是不是凱丁蜜(Are you kidding me?),然后我會說我是斯爾瑞爾斯(I'm serious.). 最后你會承認你背不出. 但我要問你:誰是丁喜?百里長青與丁喜是什么關系?這本小說講了一個什么故事?你一定滔滔不絕.
背不出一章內容,但卻能說出整本小說的故事梗概, 是因為我們會提取主要信息, 不會對信息'一視同仁',懂得取舍.有些信息比如環境描寫看看就過去了,一般不會刻意去記憶.但有些重要線索,比如誰殺了誰等等這樣的信息我們會記住.
回過頭來再看RNN,繼續忽略恐怖的指數函數,直觀的理解一下: RNN讀取的信息,對信息一視同仁:經過處理的信息,RNN認為這些信息的任何一部分都對接下來的信息有影響,全部都拋給接下來處理的程序.對這些信息,RNN進行同樣的處理.,造成大量無用信息冗余,浪費大量記憶空間,導致關鍵信息無法突出,更多的信息又無法存儲.從而產生較前面的信息RNN記不住的問題.這才是''本質原因". 神馬指數函數只是'劊子手'而已.
其他方法都只是從表象處理問題(針對梯度消失,或指數函數的連續乘法),或者雖針對本質原因但方法不對頭. 而門限RNN正是針對信息的重要性設計的.
LSTM原理
考慮重要性,那就自然而然的產生兩種時態信息. 一種就是長時態(Long term state)信息. 此信息包含'趨勢'信息或'主旨'信息,是剔除冗余信息后,對未來信息真正產生作用的信息.比如小說中的主旨大意,新聞要點等等.另一種短時態信息(Short term state). 此類信息是最直接地,對未來信息產生影響的信息. 比如'今天真熱啊, 我得吹吹(空調)', ''吹吹'直接導致'空調''或'風扇'的產生,而不是可樂,'涼水澡'等等.
相比傳統RNN的'一視同仁', 兩種時態信息的區分,致使長時態信息不會被短時信息所淹沒.

圖1: 信息分態
對於兩種時態信息, LSTM是如何提取重要信息的呢? 顧名思義,通過門(gate)來'提取'的.

圖2:信息流門限控制
上圖中, \(C_t\) 代表長時態信息 \(C_{t-1}\) 為前一個時刻的長時態信息),而 \(C'_t\) 則代表短時態信 息, \( h_t\) 為經過LSTM單元后的輸出信息,三條線上的開關,即為門限.圖中展示的三種門分別為:
- 前一時刻長態信息與當前時刻長態信息之間控制門: 遺忘門(Forget gate);
- 當前時刻即短態信息與長態信息之間控制門: 輸入門(Input gate);
- 當前信息(長,短匯總后)與輸出態信息之間控制門:輸出門(Output gate).
遺忘門控制的是歷史信息有多少對現在,對未來有影響,即有多少是可以繼續保留在長態信息的; 輸入門控制的是輸入信息有多少可以加入到長態信息中去;輸出門控制的是匯總后的信息有多少是可以作為當前輸出的信息.
門限控制*
門的設計根據以上信息也就不難設計:
\[gate(x) = \sigma(Wx + b) \]
其中 x 表示的是門的輸入, 而 \(\sigma\) 表示的是門限激活(控制)函數,一般為(或當前比較流行)sigmoid函數.設 t 時刻的遺忘門,輸入門及輸出門分別用 \(f_t, i_t,o_t\) 表示. 則三種的表示方式:
\[\begin{array} \\ f_t = \sigma(h_{t-1},x_t) = \sigma(W_{f,h}h_{t-1} + W_{f,x}x+b_f)\\ i_t = \sigma(h_{t-1},x_t) = \sigma(W_{i,h}h_{t-1} + W_{i ,x}x+b_i) \\ o_t = \sigma(h_{t-1},x_t) = \sigma(W_{o,h}h_{t-1} + W_{o,x}x+b_o)\\ \end{array} \]
其中 W 為權重,如 \(W_{f,h}\) 為遺忘門對應的上一時態輸出信息的權重,其他同理. b 為偏置.
門的輸入 (x) 又是什么呢?即門的開關取決於什么呢?沒錯,是單元的輸入信息,當前時刻的輸入信息包括前一時刻的輸出(\(h_{t-1}\))以及當前時刻的外部信息輸入 (\(x_t\)). 用 \(C'_t\) 表示當前輸入則:
\[C'_t = tanh(W_{C,h} h_{t-1}+W_{C,x}x + b_C) \]
試下吧.其中 tanh() 為雙曲正切函數,它可以理解成為激活函數,當然其激活函數也是可以的,只不過當前流行(或者說當前效果好)tanh(),下同.
以上,門與輸入都有了,那 t 時刻的狀態信息(Ct)就可以寫出來了(觀察圖2):
\[C_t = f_t\odot C_{t-1} + i_t \odot C'_t \]
可見,當門的值為1時,門屬於完全開放狀態,所有信息都可以通過, 而門的值為0 則表示關閉狀態,所有信息都不能通過, 而正常情況下則是(0,1)之間,即對信息是有取舍的.
t 時刻的狀態信息產生,那么 t 時刻的輸出 (ht) 就可以得出了:
\[h_t = o_t \odot tanh(C_t) \]
至此.LSTM單元就構建完成了.
LSTM 的 BPTT
對模型訓練,要更新的參數即為權重(與偏置),其中權重的設置有四處,三個門與輸入的端.
設加權輸入為 z, 則:
\[\begin{array} \\ z_{f,t} = W_{f,h} h_{t-1} + W_{f,x} x_t + b_f\\ z_{i,t} = W_{i,h} h_{t-1} + W_{i,x} x_t + b_i\\ z_{o,t} = W_{o,h} h_{t-1} + W_{o,x} x_t + b_o\\ z_{f,t} = W_{c',h} h_{t-1} + W_{C',x} x_t + b_{C'}\\ \end{array} \]
設其對應的誤差項為 \(\delta\), 則:
\[\begin{array}\\ \delta_{f,t} = \frac{\partial L}{\partial{ z_{f,t}}}\\ \delta_{i,t} = \frac{\partial L}{\partial{ z_{i,t}}}\\ \delta_{o,t} = \frac{\partial L}{\partial{ z_{o,t}}}\\ \delta_{C',t} = \frac{\partial L}{\partial{ z_{C',t}}}\\ \end{array} \]
其中 L 為損失函數.
設 t 時刻的誤差項為 \(\delta_t\):
\[\delta_t = \frac{\partial L}{\partial h_t} \]
則 t-1 時刻的誤差項為:
\[\delta_{t-1} = \frac{\partial L}{\partial h_{t-1}} =\frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial h_{t-1}} = \delta_t \frac{\partial h_{t}}{\partial h_{t-1}} \]
回顧下(7-10)四式:
\[\begin{array}\\ \frac{\partial h_{t}}{\partial h_{t-1}} & = & \frac{\partial h_t}{\partial o_t} \frac{\partial o_t}{\partial z_{o,t}} \frac{\partial z_{o,t}}{\partial h_{t-1}}\\ &&+\frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial f_t} \frac{\partial f_t}{\partial z_{f,t}} \frac{\partial z_{f,t}}{\partial h_{t-1}}\\ && + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial i_t} \frac{\partial i_t}{\partial z_{i,t}} \frac{\partial z_{i,t}}{\partial h_{t-1}} \\ &&+ \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial X'_t} \frac{\partial C'_t}{\partial z_{C',t}} \frac{\partial z_{C',t}}{\partial h_{t-1}} \end{array} \]
把(15)式代入(14)式, 可得:
\[\delta_{t-1} = \delta_{o,t}\frac{\partial z_{o,t}}{\partial h_{t-1}}+\delta_{f,t}\frac{\partial z_{f,t}}{\partial h_{t-1}}+\delta_{i,t}\frac{\partial z_{i,t}}{\partial h_{t-1}}+\delta_{C',t}\frac{\partial z_{C',t}}{\partial h_{t-1}} \]
其中用到了:
\[\begin{array}\\ \delta_{f,t} = \frac{\partial L}{\partial{ z_{f,t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial f_t} \frac{\partial f_t}{\partial z_{f,t}} \\ \delta_{i,t} = \frac{\partial L}{\partial{ z_{i,t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial i_t} \frac{\partial i_t}{\partial z_{i,t}} \\ \delta_{o,t} = \frac{\partial L}{\partial{ z_{o,t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial o_t} \frac{\partial o_t}{\partial z_{o,t}} \\ \delta_{C',t} = \frac{\partial L}{\partial{ z_{C',t}}} = \frac{\partial L}{\partial h_t} \frac{\partial h_{t}}{\partial C_{t}}\frac{\partial C_t}{\partial C'_t} \frac{\partial C'_t}{\partial z_{C',t}} \\ \end{array} \]
於是:
\[\delta_{t-1} = \delta_{o,t} W_{o,h}+\delta_{f,t} W_{f,h}+\delta_{i,t} W_{i,h}+\delta_{C',t} W_{C',h} \]
其中用到了:
\[\begin{array}\\ \frac{\partial z_{f,t}}{\partial h_{t-1}} =\frac{\partial (W_{f,h} h_{t-1}+W_{C',x} x_t + b_f)}{\partial h_{t-1}} = W_{f,h} \\ \frac{\partial z_{i,t}}{\partial h_{t-1}} =\frac{\partial (W_{i,h} h_{t-1}+W_{C',x} x_t + b_{i})}{\partial h_{t-1}} = W_{i,h} \\ \frac{\partial z_{o,t}}{\partial h_{t-1}} =\frac{\partial (W_{o,h} h_{t-1}+W_{o,x} x_t + b_{o})}{\partial h_{t-1}} = W_{o,h} \\ \frac{\partial z_{C',t}}{\partial h_{t-1}} =\frac{\partial (W_{C',h} h_{t-1}+W_{C',x} x_t + b_{C'})}{\partial h_{t-1}} = W_{C',h} \\ \end{array} \]
以上是誤差在時域上的傳播.
接下來探討傳播到上一層(l-1):
\[\delta^{l-1}_t = \frac{\partial L }{\partial z_t^{l-1}} \]
t 時刻的輸入 \(x_t\) :
\[x_t^l = f^{l-1}(z_{t}^{l-1}) \]
在 l 層, $z_{f,t}^l,z_{i,t}^l,z_{o,t}^l,z_{C',t}^l $ 均為\(x_t\) 的函數,則
\[\begin{array}\\ \delta^{l-1}_t & = & \frac{\partial L }{\partial z_t^{l-1}} \\ &=& \frac{\partial L}{\partial z_{f,t}^l} \frac{\partial z_{f,t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\ &&+ \frac{\partial L}{\partial z_{i,t}^l} \frac{\partial z_{i,t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\ &&+ \frac{\partial L}{\partial z_{o,t}^l} \frac{\partial z_{o,t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\ &&+ \frac{\partial L}{\partial z_{C',t}^l} \frac{\partial z_{C',t}^l}{\partial x_t^l}\frac{\partial x_t}{\partial z_{t}^{l-1}} \\ &=& \delta_{f,t}W_{f,x}\odot f'(z_{t}^{l-1}) \\ &&+ \delta_{i,t}W_{i,x}\odot f'(z_{t}^{l-1}) \\ &&+ \delta_{o,t}W_{o,x}\odot f'(z_{t}^{l-1}) \\ && +\delta_{C',t}W_{C',x}\odot f'(z_{t}^{l-1}) \\ & = & ( \delta_{f,t} W_{f,x}+ \delta_{i,t}W_{i,x}+ \delta_{o,t}W_{o,x}+ \delta_{C',t}W_{C',x})\odot f'(z_{t}^{l-1}) \end{array} \]
有以上誤差項, 梯度求解就簡單多了,t 時刻的 \(W_{f,h},W_{i,h},W_{o,h},W_{C',h}\) 分別為:
\[\begin{array}\\ \frac{\partial L}{\partial W_{f,h,t}} = \frac{\partial L}{\partial z_{f,t}}\frac{\partial z_{f,t}}{\partial W_{f,h,t}}= \delta_{f,t} h_{t-1}\\ \frac{\partial L}{\partial W_{i,h,t}} = \frac{\partial L}{\partial z_{i,t}}\frac{\partial z_{i,t}}{\partial W_{i,h,t}}= \delta_{i,t} h_{t-1}\\ \frac{\partial L}{\partial W_{o,h,t}} = \frac{\partial L}{\partial z_{o,t}}\frac{\partial z_{o,t}}{\partial W_{o,h,t}}= \delta_{o,t} h_{t-1}\\ \frac{\partial L}{\partial W_{C',h,t}} = \frac{\partial L}{\partial z_{C',t}}\frac{\partial z_{C',t}}{\partial W_{C',h,t}}= \delta_{C',t} h_{t-1}\\ \end{array} \]
各個時刻的梯度之和即最終梯度:
\[\begin{array}\\ \frac{\partial L}{\partial W_{f,h}} = \sum_{t =1}^T \delta_{f,t}h_{t-1} \\ \frac{\partial L}{\partial W_{i,h}} = \sum_{t =1}^T \delta_{i,t}h_{t-1} \\ \frac{\partial L}{\partial W_{o,h}} = \sum_{t =1}^T \delta_{o,t}h_{t-1} \\ \frac{\partial L}{\partial W_{C',h}} = \sum_{t =1}^T \delta_{C',t}h_{t-1} \\ \end{array} \]
對於偏置:
\[\begin{array}\\ \frac{\partial L}{\partial b_{f,t}} = \frac{\partial L}{\partial z_{f,t}} \frac{\partial z_{f,t}}{\partial b_{f,t}} = \delta_{f,t} \\ \frac{\partial L}{\partial b_{i,t}} = \frac{\partial L}{\partial z_{i,t}} \frac{\partial z_{i,t}}{\partial b_{i,t}} = \delta_{i,t} \\ \frac{\partial L}{\partial b_{o,t}} = \frac{\partial L}{\partial z_{o,t}} \frac{\partial z_{o,t}}{\partial b_{o,t}} = \delta_{o,t} \\ \frac{\partial L}{\partial b_{C',t}} = \frac{\partial L}{\partial z_{C',t}} \frac{\partial z_{C',t}}{\partial b_{C',t}} = \delta_{C',t} \end{array} \]
最終梯度:
\[\begin{array}\\ \frac{\partial L}{\partial b_{f,t}} = \sum_{t =1}^T \delta_{f,t}\\ \frac{\partial L}{\partial b_{i,t}} = \sum_{t =1}^T \delta_{i,t}\\ \frac{\partial L}{\partial b_{o,t}} = \sum_{t =1}^T \delta_{o,t}\\ \frac{\partial L}{\partial b_{C',t}} = \sum_{t =1}^T \delta_{C',t}\\ \end{array} \]
最后:
\[\begin{array}\\ \frac{\partial L}{\partial W_{f,x}} = \frac{\partial L}{\partial z_{f,t}}\frac{\partial z_{f,t}}{\partial W_{f,x}} = \delta_{f,x}x_t\\ \frac{\partial L}{\partial W_{i,x}} = \frac{\partial L}{\partial z_{i,t}}\frac{\partial z_{i,t}}{\partial W_{i,x}} = \delta_{i,x}x_t\\ \frac{\partial L}{\partial W_{o,x}} = \frac{\partial L}{\partial z_{o,t}}\frac{\partial z_{o,t}}{\partial W_{o,x}} = \delta_{o,x}x_t\\ \frac{\partial L}{\partial W_{C,x}} = \frac{\partial L}{\partial z_{C',t}}\frac{\partial z_{C',t}}{\partial W_{C',x}} = \delta_{C',x}x_t\\ \end{array} \]
以上就是 LSTM 的 BPTT, 似乎很多公式,但其實四種模式都是一樣的,怕大家混淆就都寫上了,只不過這樣看着會很多的樣子.
參考文獻:
1: Deep Learning, 2016, Ian Goodfellow, Yoshua Bengio, Aaron Courville.
2: Neural Networks and Deep Learning, 2016, Michael Nielsen.
3: Understanding LSTM, 2015, Colah's blog.
4: A Critical Review of Recurrent Neural Networks for Sequence Learning, 2015, Zachary C. Lipton et al.
5: 零基礎入門深度學習(6)- 長短時記憶網絡(LSTM) 2017, hanbingtao.