https://www.cnblogs.com/liujshi/p/6159007.html
LSTM的推导与实现
前言
最近在看CS224d,这里主要介绍LSTM(Long Short-Term Memory)的推导过程以及用Python进行简单的实现。LSTM是一种时间递归神经网络,是RNN的一个变种,非常适合处理和预测时间序列中间隔和延迟非常长的事件。假设我们去试着预测‘I grew up in France...(很长间隔)...I speak fluent French’最后的单词,当前的信息建议下一个此可能是一种语言的名字(因为speak嘛),但是要准确预测出‘French’我们就需要前面的离当前位置较远的‘France’作为上下文,当这个间隔比较大的时候RNN就会难以处理,而LSTM则没有这个问题。
LSTM的原理
为了弄明白LSTM的实现,我下载了alex的原文,但是被论文上图片和公式弄的晕头转向,无奈最后在网上收集了一些资料才总算弄明白。我这里不介绍就LSTM的前置RNN了,不懂的童鞋自己了解一下吧。
LSTM的前向过程
首先看一张LSTM节点的内部示意图:
图片来自一篇讲解LSTM的blog(http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
这是我认为网上画的最好的LSTM网络节点图(比论文里面画的容易理解多了),LSTM前向过程就是看图说话,关键的函数节点已经在图中标出,这里我们忽略了其中一个tanh计算过程。
这里ϕ(x)=tanh(x),σ(x)=11+e−xϕ(x)=tanh(x),σ(x)=11+e−x,x(t),h(t)x(t),h(t)分别是我们的输入序列和输出序列。如果我们把x(t)x(t)与h(t−1)h(t−1)这两个向量进行合并:
那么可以上面的方程组可以重写为:
其中f(t)f(t)被称为忘记门,所表达的含义是决定我们会从以前状态中丢弃什么信息。i(t),g(t)i(t),g(t)构成了输入门,决定什么样的新信息被存放在细胞状态中。o(t)o(t)所在位置被称作输出门,决定我们要输出什么值。这里表述的不是很准确,感兴趣的读者可以去http://colah.github.io/posts/2015-08-Understanding-LSTMs/ NLP这块我也不太懂。
前向过程的代码如下:
def bottom_data_is(self, x, s_prev = None, h_prev = None): # if this is the first lstm node in the network if s_prev == None: s_prev = np.zeros_like(self.state.s) if h_prev == None: h_prev = np.zeros_like(self.state.h) # save data for use in backprop self.s_prev = s_prev self.h_prev = h_prev # concatenate x(t) and h(t-1) xc = np.hstack((x, h_prev)) self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg) self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi) self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf) self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo) self.state.s = self.state.g * self.state.i + s_prev * self.state.f self.state.h = self.state.s * self.state.o self.x = x self.xc = xc
LSTM的反向过程
LSTM的正向过程比较容易,反向过程则比较复杂,我们先定义一个loss function l(t)=f(h(t),y(t)))=||h(t)−y(t)||2l(t)=f(h(t),y(t)))=||h(t)−y(t)||2,h(t),y(t)h(t),y(t)分别为输出序列与样本标签,我们要做的就是最小化整个时间序列上的l(t)l(t),即最小化
其中TT代表整个时间序列,下面我们通过LL来计算梯度,假设我们要计算dLdwdLdw,其中ww是一个标量(例如是矩阵WgxWgx的一个元素),由链式法则可以导出
其中hi(t)hi(t)是第i个单元的输出,MM是LSTM单元的个数,网络随着时间t前向传播,hi(t)hi(t)的改变不影响t时刻之前的loss,我们可以写出:
为了书写方便我们令L(t)=∑Ts=tl(s)L(t)=∑s=tTl(s)来简化我们的书写,这样L(1)L(1)就是整个序列的loss,重写上式有:
这样我们就可以将梯度重写为:
我们知道L(t)=l(t)+L(t+1)L(t)=l(t)+L(t+1),那么dL(t)dhi(t)=dl(t)dhi(t)+dL(t+1)dhi(t)dL(t)dhi(t)=dl(t)dhi(t)+dL(t+1)dhi(t),这说明得到下一时序的导数后可以直接得出当前时序的导数,所以我们可以计算TT时刻的导数然后往前推,在TT时刻有dL(T)dhi(T)=dl(T)dhi(T)dL(T)dhi(T)=dl(T)dhi(T)。
def y_list_is(self, y_list, loss_layer): """ Updates diffs by setting target sequence with corresponding loss layer. Will *NOT* update parameters. To update parameters, call self.lstm_param.apply_diff() """ assert len(y_list) == len(self.x_list) idx = len(self.x_list) - 1 # first node only gets diffs from label ... loss = loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx]) diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx]) # here s is not affecting loss due to h(t+1), hence we set equal to zero diff_s = np.zeros(self.lstm_param.mem_cell_ct) self.lstm_node_list[idx].top_diff_is(diff_h, diff_s) idx -= 1 ### ... following nodes also get diffs from next nodes, hence we add diffs to diff_h ### we also propagate error along constant error carousel using diff_s while idx >= 0: loss += loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx]) diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx]) diff_h += self.lstm_node_list[idx + 1].state.bottom_diff_h diff_s = self.lstm_node_list[idx + 1].state.bottom_diff_s self.lstm_node_list[idx].top_diff_is(diff_h, diff_s) idx -= 1 return loss
从上面公式可以很容易理解diff_h的计算过程。这里的loss_layer.bottom_diff定义如下:
def bottom_diff(self, pred, label): diff = np.zeros_like(pred) diff[0] = 2 * (pred[0] - label) return diff
该函数结合上文的loss function很明显。下面来推导dL(t)ds(t)dL(t)ds(t),结合前面的前向公式我们可以很容易得出s(t)s(t)的变化会直接影响h(t)h(t)和h(t+1)h(t+1),进而影响L(t)L(t),即有:
因为h(t+1)h(t+1)不影响l(t)l(t)所以有dL(t)dhi(t+1)=dL(t+1)dhi(t+1)dL(t)dhi(t+1)=dL(t+1)dhi(t+1),因此有:
同样的我们可以通过后面的导数逐级反推得到前面的导数,代码即diff_s的计算过程。
下面我们计算dL(t)dhi(t)∗dhi(t)dsi(t)dL(t)dhi(t)∗dhi(t)dsi(t),因为h(t)=s(t)∗o(t)h(t)=s(t)∗o(t),那么dL(t)dhi(t)∗dhi(t)dsi(t)=dL(t)dhi(t)∗oi(t)=oi(t)[diff_h]dL(t)dhi(t)∗dhi(t)dsi(t)=dL(t)dhi(t)∗oi(t)=oi(t)[diff_h],即dL(t)dsi(t)=o(t)[diff_h]i+[diff_s]idL(t)dsi(t)=o(t)[diff_h]i+[diff_s]i,其中[diff_h]i,[diff_s]i[diff_h]i,[diff_s]i分别表述当前t时序的dL(t)dhi(t)dL(t)dhi(t)和t+1时序的dL(t)dsi(t)dL(t)dsi(t)。同样的,结合上面的代码应该比较容易理解。
下面我们根据前向过程挨个计算导数: