Recurrent Neural Network
RNN擅長處理序列問題。下面我們就來看看RNN的原理。
可以這樣描述:如上圖所述,網絡的每一個output都會對應一個memory單元用於存儲這一時刻網絡的輸出值,
然后這個memory會作為下一時刻輸入的一部分傳入RNN,如此循環下去。
下面來看一個例子。
假設所有神經元的weight都為1,沒有bias,所有激勵函數都是linear,memory的初始值為0.
輸入序列[1,1],[1,1],[2,2].....,來以此計算輸出。
對輸入[1,1],output為1×1+1×1 + 0×1 = 2->2*1+2*1 = 4,最后輸出為[4,4],然后將[4,4]存入memory單元,作為下一時刻的部分輸入。
最后得到的輸出序列是這樣的。
而如果每次輸入的序列不同,最后的輸出序列也會不一樣。
在RNN中,每次都是使用相同的網絡結構,只是每次的輸入和memory會不同。
這樣就使我們在句子分析中,能夠辨別同一個詞出現在不同位置的時候的不同意思。
當然RNN也可以是深層的網絡。這里會有兩種不同的RNN類型Elman和Jordan。
還有雙向的RNN,可以兼顧句子的前后部分。
Long Short-term Memory (LSTM)
上面就是一個LSTM的cell的結構,每個cell有4個input, 和1個output。
其中3個input是3個gate。input gate 控制真的的input是否輸入網絡;
forget gate 控制memory是否要記得之前的時序信息;
output gate 控制是否輸出當前的得到的output。
RNN cell的3個gate輸入分別是zi,zf,zo,都是標量數據scalar,都需要通過一個激勵函數f,f通常是sigmoid,可以將正負無窮的區間壓縮到0~1,
模擬gate的開關。input z也是一個scalar,通過g(z)與f(zi)相乘,如果input gate關閉,就是f(zi)=0,那么輸入g(z)就沒有進入到cell中。
然后繼續往下走,有c' = g(x)*f(zi) + c*f(zf), 如果forget gate的f(zf)為1,就相當於記得memory的值,可以加上c,然后將c‘存入到memory中。
真正的輸出會是a = h(c')*f(zo),若f(zo)=0,則不輸出當前值。
因為在RNN中需要處理4個input,所以參數會是傳統前向傳播網絡的4倍。
假如上一時刻memory的值為ct-1,是一個vector,輸入為xt,然后轉化為4個input向量。
首先zi通過activation function與z相乘,zf通過activation function與memory中的ct-1相乘,然后把這兩個結果相加,得到新的memory中的值ct,
zo通過activation function與剛剛的輸出相乘得到最終的輸出yt。
但是通常RNN還會將上一時刻的輸出ht和memory中的值ct,再加上輸入xt一起作為輸入來操控RNN。如上圖所示。
以上就是RNN和LSTM的基本原理。
RNN的訓練
DNN和CNN都可以使用gradient decent 來訓練,RNN也可以。
RNN是基於時間序列的Backpropagation through time(BPTT)。
但是訓練結果通常是這樣的:
RNN的total loss會發生劇烈的震盪,相當不穩定,無法收斂。
這是因為RNN的error surface 很崎嶇,有平坦的地方,也有梯度很陡的地方。
這樣就是梯度的變化很大,有時候參數w很小的更新就會造成很大的梯度變化,導致loss劇烈震盪。(clipping)
但是造成這種情況的原因並不是我們使用了sigmoid,而且在RNN上使用Relu往往會得到更壞的結果。
來看一個簡單的例子,一個最簡單的RNN。
參數w=1時,輸出為1;參數w=1.01時,輸出為20000.
w的變化對輸出值有不同程度的影響,致使learning rate不好選擇。
綜上所述,RNN的訓練很困難。
RNN很難訓練,那有什么解決辦法嗎?那就是LSTM。
LSTM可以去掉error surface上“平坦”的地方,使梯度不至於特別小,這樣可以解決梯度彌散的問題。
訓練LSTM的時候,需要將learning rate 調整到很小。
簡單的RNN每次訓練后都會將memory中的信息覆蓋,放入當前結果到memory中。
而LSTM是不一樣的,它每次都將memory中的值乘上forget gate的值再加上input,放入cell,所以如果當前的w對memory有影響,那么這個影響將持續存在,在RNN中每個時刻的memory都會被覆蓋,w的影響都不再存在。所以只要forget gate 一直打開,memory中的值都會被加到新的input中,而不會消失,這樣就解決了gradient valish的問題。