This post assembles the LSTM derivation from the following three articles; all formulas are taken from them, and only the more conceptual points are covered here. If you want to code this up, you will still need to work through those articles yourself.
Forward pass:
1. Compute the net inputs to the three gates (input, output, forget) and the net input to the cell:
\begin{align}
z_{in_j}(t) = \sum\limits_m w_{in_j m}\, y_m(t - 1) + \sum\limits_{v = 1}^{S_j} w_{in_j c_j^v}\, S_{c_j^v}(t - 1),
\end{align}
\begin{align}
z_{\varphi_j}(t) = \sum\limits_m w_{\varphi_j m}\, y_m(t - 1) + \sum\limits_{v = 1}^{S_j} w_{\varphi_j c_j^v}\, S_{c_j^v}(t - 1),
\end{align}
\begin{align}
z_{out_j}(t) = \sum\limits_m w_{out_j m}\, y_m(t - 1) + \sum\limits_{v = 1}^{S_j} w_{out_j c_j^v}\, S_{c_j^v}(t - 1),
\end{align}
\begin{align}
z_{c_j^v}(t) = \sum\limits_m w_{c_j^v m}\, y_m(t - 1) + \sum\limits_{v' = 1}^{S_j} w_{c_j^v c_j^{v'}}\, S_{c_j^{v'}}(t - 1),
\end{align}
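A minimal numpy sketch of step 1, with made-up sizes and random weights (every name here, `w_in`, `w_phi`, `w_out`, `w_c`, and so on, is hypothetical and not from the articles). The second sum in each formula is the peephole contribution from the previous cell states:

```python
import numpy as np

rng = np.random.default_rng(0)
M, S = 4, 2                       # M source units y_m(t-1); S_j = S cells in block j

y_prev = rng.standard_normal(M)   # y_m(t-1): activations of connected units
s_prev = rng.standard_normal(S)   # S_{c_j^v}(t-1): previous cell states

# unit weights (one per source unit m) and peephole weights (one per cell v)
w_in,  w_in_c  = rng.standard_normal(M), rng.standard_normal(S)
w_phi, w_phi_c = rng.standard_normal(M), rng.standard_normal(S)
w_out, w_out_c = rng.standard_normal(M), rng.standard_normal(S)
w_c,   w_c_c   = rng.standard_normal((S, M)), rng.standard_normal((S, S))

z_in  = w_in  @ y_prev + w_in_c  @ s_prev   # z_{in_j}(t), one scalar per block
z_phi = w_phi @ y_prev + w_phi_c @ s_prev   # z_{phi_j}(t)
z_out = w_out @ y_prev + w_out_c @ s_prev   # z_{out_j}(t)
z_c   = w_c   @ y_prev + w_c_c   @ s_prev   # z_{c_j^v}(t), one per cell v
```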
2. Compute the activations of each gate and of the cell:
\begin{align}
{y_{i{n_j}}}(t) = {f_{i{n_j}}}({z_{i{n_j}}}(t)),
\end{align}
\begin{align}
{y_{{\varphi _j}}}(t) = {f_{{\varphi _j}}}({z_{{\varphi _j}}}(t)),
\end{align}
\begin{align}
{y_{ou{t_j}}}(t) = {f_{ou{t_j}}}({z_{ou{t_j}}}(t)),
\end{align}
\begin{align}
{S_{c_j^v}}(0) = 0,{S_{c_j^v}}(t) = {y_{{\varphi _j}}}(t){S_{c_j^v}}(t - 1) + {y_{i{n_j}}}(t)g({z_{c_j^v}}(t)),
\end{align}
\begin{align}
y_{c_j^v}(t) = y_{out_j}(t)\, h(S_{c_j^v}(t)),
\end{align}
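Step 2 can be sketched as follows; logistic sigmoids for the gates and tanh for $g$ and $h$ are common choices, but the articles leave $f$, $g$, $h$ generic, and the input values below are made up:

```python
import numpy as np

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
g = np.tanh   # cell-input squashing g
h = np.tanh   # cell-output squashing h

# toy net inputs from step 1 (hypothetical values)
z_in, z_phi, z_out = 0.5, 1.2, -0.3
z_c    = np.array([0.2, -0.7])         # one net input per cell
s_prev = np.array([0.1,  0.4])         # S_{c_j^v}(t-1); S(0) = 0 at the first step

y_in, y_phi, y_out = sigmoid(z_in), sigmoid(z_phi), sigmoid(z_out)

s   = y_phi * s_prev + y_in * g(z_c)   # S_{c_j^v}(t): keep gated old state, add gated input
y_c = y_out * h(s)                     # y_{c_j^v}(t): gated cell output
```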
3. Assume the network has a standard three-layer structure (as shown in the figure below): one input layer, one hidden layer, and one output layer. For an output unit, its net input and activation can then be computed as follows, where m ranges over all units connected to that output unit (both input-layer and hidden-layer units).
\begin{align}
z_k(t) = \sum\limits_m w_{km}\, y_m(t - 1),
\end{align}
\begin{align}
{y_k}(t) = {f_k}({z_k}(t)),
\end{align}
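Step 3 in code, with a sigmoid output unit over hypothetical toy values:

```python
import numpy as np

rng = np.random.default_rng(1)
f_k = lambda x: 1.0 / (1.0 + np.exp(-x))   # output squashing f_k

y_m = rng.standard_normal(5)   # all units m feeding output unit k
w_k = rng.standard_normal(5)   # weights w_{km}

z_k = w_k @ y_m   # net input z_k(t)
y_k = f_k(z_k)    # activation y_k(t)
```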
4. Compute the partial derivatives of the current cell state with respect to the input-gate, forget-gate, and cell weights. This differs from a CNN, where the forward pass only computes values and carries no gradients: in an LSTM the internal state depends on the state at the previous time step, so the state's parameters also carry error forward to the next time step, and the forward pass therefore accumulates these derivatives as well.
\begin{align}
dS_{in,m}^{jv}(t) = \frac{\partial S_{c_j^v}(t)}{\partial w_{in_j m}} \overset{tr}{=} \frac{\partial S_{c_j^v}(t - 1)}{\partial w_{in_j m}}\, y_{\varphi_j}(t) + g(z_{c_j^v}(t))\, f'_{in_j}(z_{in_j}(t))\, y_m(t - 1),
\end{align}
\begin{align}
dS_{\varphi m}^{jv}(t) = \frac{\partial S_{c_j^v}(t)}{\partial w_{\varphi_j m}} \overset{tr}{=} \frac{\partial S_{c_j^v}(t - 1)}{\partial w_{\varphi_j m}}\, y_{\varphi_j}(t) + S_{c_j^v}(t - 1)\, f'_{\varphi_j}(z_{\varphi_j}(t))\, y_m(t - 1),
\end{align}
\begin{align}
dS_{cm}^{jv}(t) = \frac{\partial S_{c_j^v}(t)}{\partial w_{c_j^v m}} \overset{tr}{=} \frac{\partial S_{c_j^v}(t - 1)}{\partial w_{c_j^v m}}\, y_{\varphi_j}(t) + g'(z_{c_j^v}(t))\, y_{in_j}(t)\, y_m(t - 1),
\end{align}
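These three recurrences can be accumulated during the forward pass. A sketch with made-up values, storing the partials as shape (S, M) arrays so entry [v, m] holds $dS^{jv}_{\cdot,m}$; each step scales the old partials by $y_{\varphi_j}(t)$ (the derivative of the state recurrence with respect to the old state) and adds the new term:

```python
import numpy as np

rng = np.random.default_rng(2)
sigmoid  = lambda x: 1.0 / (1.0 + np.exp(-x))
dsigmoid = lambda x: sigmoid(x) * (1.0 - sigmoid(x))   # gate f'
g, dg = np.tanh, lambda x: 1.0 - np.tanh(x) ** 2       # g and g'

M, S = 3, 2
y_prev = rng.standard_normal(M)    # y_m(t-1)
s_prev = rng.standard_normal(S)    # S_{c_j^v}(t-1)
z_in, z_phi = 0.4, 0.9             # gate net inputs (toy values)
z_c = rng.standard_normal(S)       # cell net inputs
y_in, y_phi = sigmoid(z_in), sigmoid(z_phi)

# accumulated partials dS^{jv}_m, shape (S, M); dS(0) = 0
dS_in, dS_phi, dS_c = (np.zeros((S, M)) for _ in range(3))

# one forward step of each truncated recurrence
dS_in  = dS_in  * y_phi + np.outer(g(z_c) * dsigmoid(z_in),  y_prev)
dS_phi = dS_phi * y_phi + np.outer(s_prev * dsigmoid(z_phi), y_prev)
dS_c   = dS_c   * y_phi + np.outer(dg(z_c) * y_in,           y_prev)
```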
Backward pass:
1. For each output unit, we compute its output error as follows, where ${t_k}(t)$ is the target (true) value and ${y_k}(t)$ is the output computed in the forward pass.
\begin{align}
{e_k}(t) = {t_k}(t) - {y_k}(t),
\end{align}
2. Next, compute the delta (residual) of each output unit. This is computed the same way as in a CNN: differentiate through that layer.
\begin{align}
\delta_k(t) = f'_k(z_k(t))\, e_k(t),
\end{align}
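Steps 1 and 2 of the backward pass in code, with sigmoid output units and made-up targets:

```python
import numpy as np

sigmoid  = lambda x: 1.0 / (1.0 + np.exp(-x))
dsigmoid = lambda x: sigmoid(x) * (1.0 - sigmoid(x))   # f'_k

z_k = np.array([0.3, -1.1])    # output net inputs from the forward pass
y_k = sigmoid(z_k)             # network outputs y_k(t)
t_k = np.array([1.0, 0.0])     # targets t_k(t)

e_k     = t_k - y_k            # output error e_k(t)
delta_k = dsigmoid(z_k) * e_k  # output-unit delta
```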
3. The delta of the output gate is computed in a way similar to that of an output unit. (An output unit's delta concerns only that single unit's weights, whereas the output gate's delta involves the weights of all output units connecting to the output layer.)
\begin{align}
\delta_{out_j}(t) = f'_{out_j}(z_{out_j}(t)) \left( \sum\nolimits_{v = 1}^{S_j} h(S_{c_j^v}(t)) \sum\nolimits_k w_{k c_j^v}\, \delta_k(t) \right),
\end{align}
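The output-gate delta in code (all sizes and values are hypothetical): the inner sum over k backpropagates the output-unit deltas to each cell, and the outer sum weights them by $h(S_{c_j^v}(t))$:

```python
import numpy as np

rng = np.random.default_rng(3)
sigmoid  = lambda x: 1.0 / (1.0 + np.exp(-x))
dsigmoid = lambda x: sigmoid(x) * (1.0 - sigmoid(x))
h = np.tanh

K, S = 2, 3                            # K output units, S cells in block j
w_kc    = rng.standard_normal((K, S))  # weights w_{k c_j^v}
delta_k = rng.standard_normal(K)       # output-unit deltas
s       = rng.standard_normal(S)       # current cell states S_{c_j^v}(t)
z_out   = 0.2                          # output-gate net input

back = w_kc.T @ delta_k                # sum_k w_{k c_j^v} delta_k(t), one per cell v
delta_out = dsigmoid(z_out) * np.sum(h(s) * back)   # delta_{out_j}(t), a scalar
```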
4. Items 2 and 3 cover the external deltas; the internal error (covering the input gate, forget gate, and cell) is computed as follows:
\begin{align}
{e_{{S_{c_j^v}}}}(t) = {y_{ou{t_j}}}(t)h'({S_{c_j^v}}(t))(\sum\nolimits_k {{w_{kc_j^v}}{\delta _k}(t)} ),
\end{align}
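And the internal cell-state error in code (again with hypothetical toy values), one value per cell v:

```python
import numpy as np

rng = np.random.default_rng(4)
dh = lambda x: 1.0 - np.tanh(x) ** 2   # h'

K, S = 2, 3
w_kc    = rng.standard_normal((K, S))  # w_{k c_j^v}
delta_k = rng.standard_normal(K)       # output-unit deltas
s       = rng.standard_normal(S)       # S_{c_j^v}(t)
y_out   = 0.7                          # output-gate activation y_{out_j}(t)

e_s = y_out * dh(s) * (w_kc.T @ delta_k)   # e_{S_{c_j^v}}(t), one per cell v
```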
5. Finally, update each weight using the errors/deltas above. Note that the external and internal expressions have different forms; see the original articles for the full derivation.
1).output unit:
\begin{align}
\Delta {w_{km}}(t) = \alpha {\delta _k}(t){y_m}(t - 1),
\end{align}
2).output gate:
\begin{align}
\Delta w_{out,m}(t) = \alpha\, \delta_{out_j}(t)\, y_m(t - 1),
\end{align}
3).input gate:
\begin{align}
\Delta {w_{in,m}}(t) = \alpha \sum\nolimits_{v = 1}^{{S_j}} {{e_{{S_{c_j^v}}}}(t)dS_{in,m}^{jv}(t)} ,
\end{align}
4).forget gate:
\begin{align}
\Delta {w_{\varphi m}}(t) = \alpha \sum\nolimits_{v = 1}^{{S_j}} {{e_{{S_{c_j^v}}}}(t)dS_{\varphi m}^{jv}(t)} ,
\end{align}
5).cell:
\begin{align}
\Delta {w_{c_j^vm}}(t) = \alpha {e_{{S_{c_j^v}}}}(t)dS_{cm}^{jv}(t),
\end{align}
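The five update rules together in one sketch: all inputs are made-up placeholders with the shapes used in the earlier snippets, and α is the learning rate. Note the input-gate and forget-gate updates sum over the cells v, while the cell update does not:

```python
import numpy as np

rng = np.random.default_rng(5)
alpha = 0.1                             # learning rate
M, S, K = 3, 2, 2

y_prev    = rng.standard_normal(M)      # y_m(t-1)
delta_k   = rng.standard_normal(K)      # output-unit deltas
delta_out = 0.05                        # output-gate delta (scalar)
e_s       = rng.standard_normal(S)      # internal errors e_{S_{c_j^v}}(t)
dS_in, dS_phi, dS_c = (rng.standard_normal((S, M)) for _ in range(3))

dw_km  = alpha * np.outer(delta_k, y_prev)  # 1) output units
dw_out = alpha * delta_out * y_prev         # 2) output gate
dw_in  = alpha * e_s @ dS_in                # 3) input gate: summed over cells v
dw_phi = alpha * e_s @ dS_phi               # 4) forget gate: summed over cells v
dw_c   = alpha * e_s[:, None] * dS_c        # 5) cells: per-cell, no sum over v
```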