https://www.cnblogs.com/liujshi/p/6159007.html
LSTM的推導與實現
前言
最近在看CS224d,這里主要介紹LSTM(Long Short-Term Memory)的推導過程以及用Python進行簡單的實現。LSTM是一種時間遞歸神經網絡,是RNN的一個變種,非常適合處理和預測時間序列中間隔和延遲非常長的事件。假設我們去試着預測‘I grew up in France...(很長間隔)...I speak fluent French’最后的單詞,當前的信息建議下一個此可能是一種語言的名字(因為speak嘛),但是要准確預測出‘French’我們就需要前面的離當前位置較遠的‘France’作為上下文,當這個間隔比較大的時候RNN就會難以處理,而LSTM則沒有這個問題。
LSTM的原理
為了弄明白LSTM的實現,我下載了alex的原文,但是被論文上圖片和公式弄的暈頭轉向,無奈最后在網上收集了一些資料才總算弄明白。我這里不介紹就LSTM的前置RNN了,不懂的童鞋自己了解一下吧。
LSTM的前向過程
首先看一張LSTM節點的內部示意圖:
圖片來自一篇講解LSTM的blog(http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
這是我認為網上畫的最好的LSTM網絡節點圖(比論文里面畫的容易理解多了),LSTM前向過程就是看圖說話,關鍵的函數節點已經在圖中標出,這里我們忽略了其中一個tanh計算過程。
這里ϕ(x)=tanh(x),σ(x)=11+e−xϕ(x)=tanh(x),σ(x)=11+e−x,x(t),h(t)x(t),h(t)分別是我們的輸入序列和輸出序列。如果我們把x(t)x(t)與h(t−1)h(t−1)這兩個向量進行合並:
那么可以上面的方程組可以重寫為:
其中f(t)f(t)被稱為忘記門,所表達的含義是決定我們會從以前狀態中丟棄什么信息。i(t),g(t)i(t),g(t)構成了輸入門,決定什么樣的新信息被存放在細胞狀態中。o(t)o(t)所在位置被稱作輸出門,決定我們要輸出什么值。這里表述的不是很准確,感興趣的讀者可以去http://colah.github.io/posts/2015-08-Understanding-LSTMs/ NLP這塊我也不太懂。
前向過程的代碼如下:
def bottom_data_is(self, x, s_prev = None, h_prev = None): # if this is the first lstm node in the network if s_prev == None: s_prev = np.zeros_like(self.state.s) if h_prev == None: h_prev = np.zeros_like(self.state.h) # save data for use in backprop self.s_prev = s_prev self.h_prev = h_prev # concatenate x(t) and h(t-1) xc = np.hstack((x, h_prev)) self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg) self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi) self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf) self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo) self.state.s = self.state.g * self.state.i + s_prev * self.state.f self.state.h = self.state.s * self.state.o self.x = x self.xc = xc
LSTM的反向過程
LSTM的正向過程比較容易,反向過程則比較復雜,我們先定義一個loss function l(t)=f(h(t),y(t)))=||h(t)−y(t)||2l(t)=f(h(t),y(t)))=||h(t)−y(t)||2,h(t),y(t)h(t),y(t)分別為輸出序列與樣本標簽,我們要做的就是最小化整個時間序列上的l(t)l(t),即最小化
其中TT代表整個時間序列,下面我們通過LL來計算梯度,假設我們要計算dLdwdLdw,其中ww是一個標量(例如是矩陣WgxWgx的一個元素),由鏈式法則可以導出
其中hi(t)hi(t)是第i個單元的輸出,MM是LSTM單元的個數,網絡隨着時間t前向傳播,hi(t)hi(t)的改變不影響t時刻之前的loss,我們可以寫出:
為了書寫方便我們令L(t)=∑Ts=tl(s)L(t)=∑s=tTl(s)來簡化我們的書寫,這樣L(1)L(1)就是整個序列的loss,重寫上式有:
這樣我們就可以將梯度重寫為:
我們知道L(t)=l(t)+L(t+1)L(t)=l(t)+L(t+1),那么dL(t)dhi(t)=dl(t)dhi(t)+dL(t+1)dhi(t)dL(t)dhi(t)=dl(t)dhi(t)+dL(t+1)dhi(t),這說明得到下一時序的導數后可以直接得出當前時序的導數,所以我們可以計算TT時刻的導數然后往前推,在TT時刻有dL(T)dhi(T)=dl(T)dhi(T)dL(T)dhi(T)=dl(T)dhi(T)。
def y_list_is(self, y_list, loss_layer): """ Updates diffs by setting target sequence with corresponding loss layer. Will *NOT* update parameters. To update parameters, call self.lstm_param.apply_diff() """ assert len(y_list) == len(self.x_list) idx = len(self.x_list) - 1 # first node only gets diffs from label ... loss = loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx]) diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx]) # here s is not affecting loss due to h(t+1), hence we set equal to zero diff_s = np.zeros(self.lstm_param.mem_cell_ct) self.lstm_node_list[idx].top_diff_is(diff_h, diff_s) idx -= 1 ### ... following nodes also get diffs from next nodes, hence we add diffs to diff_h ### we also propagate error along constant error carousel using diff_s while idx >= 0: loss += loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx]) diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx]) diff_h += self.lstm_node_list[idx + 1].state.bottom_diff_h diff_s = self.lstm_node_list[idx + 1].state.bottom_diff_s self.lstm_node_list[idx].top_diff_is(diff_h, diff_s) idx -= 1 return loss
從上面公式可以很容易理解diff_h的計算過程。這里的loss_layer.bottom_diff定義如下:
def bottom_diff(self, pred, label): diff = np.zeros_like(pred) diff[0] = 2 * (pred[0] - label) return diff
該函數結合上文的loss function很明顯。下面來推導dL(t)ds(t)dL(t)ds(t),結合前面的前向公式我們可以很容易得出s(t)s(t)的變化會直接影響h(t)h(t)和h(t+1)h(t+1),進而影響L(t)L(t),即有:
因為h(t+1)h(t+1)不影響l(t)l(t)所以有dL(t)dhi(t+1)=dL(t+1)dhi(t+1)dL(t)dhi(t+1)=dL(t+1)dhi(t+1),因此有:
同樣的我們可以通過后面的導數逐級反推得到前面的導數,代碼即diff_s的計算過程。
下面我們計算dL(t)dhi(t)∗dhi(t)dsi(t)dL(t)dhi(t)∗dhi(t)dsi(t),因為h(t)=s(t)∗o(t)h(t)=s(t)∗o(t),那么dL(t)dhi(t)∗dhi(t)dsi(t)=dL(t)dhi(t)∗oi(t)=oi(t)[diff_h]dL(t)dhi(t)∗dhi(t)dsi(t)=dL(t)dhi(t)∗oi(t)=oi(t)[diff_h],即dL(t)dsi(t)=o(t)[diff_h]i+[diff_s]idL(t)dsi(t)=o(t)[diff_h]i+[diff_s]i,其中[diff_h]i,[diff_s]i[diff_h]i,[diff_s]i分別表述當前t時序的dL(t)dhi(t)dL(t)dhi(t)和t+1時序的dL(t)dsi(t)dL(t)dsi(t)。同樣的,結合上面的代碼應該比較容易理解。
下面我們根據前向過程挨個計算導數: