Keras(五)LSTM 長短期記憶模型 原理及實例


原文鏈接:http://www.one2know.cn/keras6/

  • LSTM 是 long-short term memory 的簡稱, 中文叫做 長短期記憶. 是當下最流行的 RNN 形式之一
  • RNN 的弊端
    RNN沒有長久的記憶,比如一個句子太長時開頭部分可能會忘記,從而給出錯誤的答案。
    時間遠的記憶要進過長途跋涉才能抵達最后一個時間點. 然后我們得到誤差, 而且在 反向傳遞 得到的誤差的時候, 他在每一步都會 乘以一個自己的參數 W. 如果這個 W 是一個小於1 的數, 比如0.9. 這個0.9 不斷乘以誤差, 誤差傳到初始時間點也會是一個接近於零的數, 所以對於初始時刻, 誤差相當於就消失了. 我們把這個問題叫做梯度消失或者梯度彌散 Gradient vanishing. 反之如果 W 是一個大於1 的數, 比如1.1 不斷累乘, 則到最后變成了無窮大的數, RNN被這無窮大的數撐死了, 這種情況我們叫做梯度爆炸, Gradient exploding. 這就是普通 RNN 沒有辦法回憶起久遠記憶的原因。
  • LSTM網絡

    在上圖中,每一行攜帶一個完整的向量,從一個節點的輸出到另一個節點的輸入。粉紅的圓圈代表逐點操作,如矢量加法,而黃色的方框是學習神經網絡層。行合並表示連接,而行分叉表示復制的內容以及復制到不同位置的內容。
  • 核心理念
    LSTM的關鍵是單元狀態,即貫穿圖頂部的水平線。單元狀態有點像傳送帶。它沿着整個鏈條直行,只有一些微小的線性相互作用。信息很容易保持不變地沿着它流動。

    LSTM可以去除或增加單元狀態的信息,並被稱為門(gates)的結構仔細調控,它們由一個sigmoid神經網絡層和一個逐點乘法運算組成。sigmoid輸出層的輸出介於0和1之間的數字,描述每個組件應該通過多少,0表示不讓任何東西通過,1表示可以通過。
  • 遺忘門
    遺忘門(forget gate)顧名思義,是控制是否遺忘的,在LSTM中即以一定的概率控制是否遺忘上一層的隱藏細胞狀態。遺忘門子結構如下圖:

    圖中輸入的有上一序列的隱藏狀態h(t−1)和本序列數據x(t),通過一個激活函數,一般是sigmoid,得到遺忘門的輸出f(t)。由於sigmoid的輸出f(t)在[0,1]之間,因此這里的輸出f^{(t)}代表了遺忘上一層隱藏細胞狀態的概率。用數學表達式即為:
    f(t)=σ(Wfh(t−1)+Ufx(t)+bf)f(t)=σ(Wfh(t−1)+Ufx(t)+bf)
    其中Wf,Uf,bfWf,Uf,bf為線性關系的系數和偏倚,和RNN中的類似,σ為sigmoid激活函數。
  • 輸入門
    輸入門(input gate)負責處理當前序列位置的輸入,它的子結構如下圖:

    從圖中可以看到輸入門由兩部分組成,第一部分使用了sigmoid激活函數,輸出為i(t),第二部分使用了tanh激活函數,輸出為a(t), 兩者的結果后面會相乘再去更新細胞狀態。用數學表達式即為:
    i(t)=σ(Wih(t−1)+Uix(t)+bi)i(t)=σ(Wih(t−1)+Uix(t)+bi)
    a(t)=tanh(Wah(t−1)+Uax(t)+ba)a(t)=tanh(Wah(t−1)+Uax(t)+ba)
    其中Wi,Ui,bi,Wa,Ua,ba,Wi,Ui,bi,Wa,Ua,ba,為線性關系的系數和偏倚,和RNN中的類似,σσ為sigmoid激活函數。
  • 細胞狀態更新
    在研究LSTM輸出門之前,我們要先看看LSTM之細胞狀態。前面的遺忘門和輸入門的結果都會作用於細胞狀態C(t)。我們來看看從細胞狀態C(t−1)如何得到C(t)。如下圖所示:

    細胞狀態C(t)由兩部分組成,第一部分是C(t−1)和遺忘門輸出f(t)f(t)的乘積,第二部分是輸入門的i(t)和a(t)的乘積,即:
    C(t)=C(t−1)⊙f(t)+i(t)⊙a(t)C(t)=C(t−1)⊙f(t)+i(t)⊙a(t)
    其中,⊙為Hadamard積(對應位置相乘),在DNN中也用到過。
  • 輸出門
    有了新的隱藏細胞狀態C(t),我們就可以來看輸出門了,子結構如下:

    從圖中可以看出,隱藏狀態h(t)的更新由兩部分組成,第一部分是o(t), 它由上一序列的隱藏狀態h(t−1)和本序列數據x(t),以及激活函數sigmoid得到,第二部分由隱藏狀態C(t)和tanh激活函數組成, 即:
    o(t)=σ(Woh(t−1)+Uox(t)+bo)o(t)=σ(Woh(t−1)+Uox(t)+bo)
    h(t)=o(t)⊙tanh(C(t))h(t)=o(t)⊙tanh(C(t))
    通過本節的剖析,相信大家對於LSTM的模型結構已經有了解了。當然,有些LSTM的結構和上面的LSTM圖稍有不同,但是原理是完全一樣的。
  • LSTM前向傳播算法
    LSTM模型有兩個隱藏狀態h(t),C(t),模型參數幾乎是RNN的4倍,因為現在多了Wf,Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,boWf,Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo這些參數。
    前向傳播過程在每個序列索引位置的過程為:
    1)更新遺忘門輸出:
    f(t)=σ(Wfh(t−1)+Ufx(t)+bf)f(t)=σ(Wfh(t−1)+Ufx(t)+bf)
    2)更新輸入門兩部分輸出:
    i(t)=σ(Wih(t−1)+Uix(t)+bi)i(t)=σ(Wih(t−1)+Uix(t)+bi)
    a(t)=tanh(Wah(t−1)+Uax(t)+ba)a(t)=tanh(Wah(t−1)+Uax(t)+ba)
    3)更新細胞狀態:
    C(t)=C(t−1)⊙f(t)+i(t)⊙a(t)C(t)=C(t−1)⊙f(t)+i(t)⊙a(t)
    4)更新輸出門輸出:
    o(t)=σ(Woh(t−1)+Uox(t)+bo)o(t)=σ(Woh(t−1)+Uox(t)+bo)
    h(t)=o(t)⊙tanh(C(t))h(t)=o(t)⊙tanh(C(t))
    5)更新當前序列索引預測輸出:
    ŷ (t)=σ(Vh(t)+c)y^(t)=σ(Vh(t)+c)
  • LSTM反向傳播算法
    有了LSTM前向傳播算法,推導反向傳播算法就很容易了, 思路和RNN的反向傳播算法思路一致,也是通過梯度下降法迭代更新我們所有的參數,關鍵點在於計算所有參數基於損失函數的偏導數。
    在RNN中,為了反向傳播誤差,我們通過隱藏狀態h(t)的梯度δ(t)一步步向前傳播。在LSTM這里也類似,只不過我們這里有兩個隱藏狀態h(t)和C(t),這里我們定義兩個δ,即:
    δ(t)h=∂L∂h(t)δh(t)=∂L∂h(t)
    δ(t)C=∂L∂C(t)δC(t)=∂L∂C(t)
    反向傳播時只使用了δ(t)CδC(t),變量δ(t)hδh(t)僅為幫助我們在某一層計算用,並沒有參與反向傳播,這里要注意。如下圖所示:

    而在最后的序列索引位置ττ的δ(τ)hδh(τ)和 δ(τ)CδC(τ)為:
    δ(τ)h=∂L∂O(τ)∂O(τ)∂h(τ)=VT(ŷ (τ)−y(τ))δh(τ)=∂L∂O(τ)∂O(τ)∂h(τ)=VT(y^(τ)−y(τ))
    δ(τ)C=∂L∂h(τ)∂h(τ)∂C(τ)=δ(τ)h⊙o(τ)⊙(1−tanh2(C(τ)))δC(τ)=∂L∂h(τ)∂h(τ)∂C(τ)=δh(τ)⊙o(τ)⊙(1−tanh2(C(τ)))
    接着我們由δ(t+1)CδC(t+1)反向推導δ(t)CδC(t)。
    δ(t)hδh(t)的梯度由本層的輸出梯度誤差決定,即:
    δ(t)h=∂L∂h(t)=VT(ŷ (t)−y(t))δh(t)=∂L∂h(t)=VT(y^(t)−y(t))
    而δ(t)CδC(t)的反向梯度誤差由前一層δ(t+1)CδC(t+1)的梯度誤差和本層的從h(t)h(t)傳回來的梯度誤差兩部分組成,即:
    δ(t)C=∂L∂C(t+1)∂C(t+1)∂C(t)+∂L∂h(t)∂h(t)∂C(t)=δ(t+1)C⊙f(t+1)+δ(t)h⊙o(t)⊙(1−tanh2(C(t)))
    δC(t)=∂L∂C(t+1)∂C(t+1)∂C(t)+∂L∂h(t)∂h(t)∂C(t)=δC(t+1)⊙f(t+1)+δh(t)⊙o(t)⊙(1−tanh2(C(t)))
    有了δ(t)hδh(t)和δ(t)CδC(t), 計算這一大堆參數的梯度就很容易了,這里只給出WfWf的梯度計算過程,其他的Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo,V,cUf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo,V,c的梯度大家只要照搬就可以了。
    ∂L∂Wf=∑t=1τ∂L∂C(t)∂C(t)∂f(t)∂f(t)∂Wf=∑t=1τδ(t)C⊙C(t−1)⊙f(t)⊙(1−f(t))(h(t−1))
    T∂L∂Wf=∑t=1τ∂L∂C(t)∂C(t)∂f(t)∂f(t)∂Wf=∑t=1τδC(t)⊙C(t−1)⊙f(t)⊙(1−f(t))(h(t−1))T

LSTM 實例

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# 加載數據集
dataset_train = pd.read_csv('平安銀行.csv',encoding='gb18030')
training_set = dataset_train.iloc[:,1:2].values
print(dataset_train.head()) # 查看一下數據的格式

# 特征縮放
from sklearn.preprocessing import MinMaxScaler
sc = MinMaxScaler(feature_range=(0,1))
training_set_scaled = sc.fit_transform(training_set)

# 使用Timesteps創建數據
X_train = []
y_train = []
for i in range(60, 2035):
    X_train.append(training_set_scaled[i-60:i, 0]) # 訓練集為早60個的數據
    y_train.append(training_set_scaled[i, 0])
X_train, y_train = np.array(X_train), np.array(y_train)

X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))

# 構建LSTM
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout

regressor = Sequential()

regressor.add(LSTM(units = 50, return_sequences = True, input_shape = (X_train.shape[1], 1)))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units = 50, return_sequences = True))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units = 50, return_sequences = True))
regressor.add(Dropout(0.2))

regressor.add(LSTM(units = 50))
regressor.add(Dropout(0.2))

regressor.add(Dense(units = 1))

regressor.compile(optimizer = 'adam', loss = 'mean_squared_error')

regressor.fit(X_train, y_train, epochs = 20, batch_size = 32)

# 預測未來的股票
dataset_test = pd.read_csv('平安銀行.csv',encoding='gb18030')
y_test = dataset_test.iloc[:, 1:2].values

dataset_total = pd.concat((dataset_train['開盤價(元)'], dataset_test['開盤價(元)']), axis = 0)
inputs = dataset_total[len(dataset_total) - len(dataset_test) - 60:].values
inputs = inputs.reshape(-1,1)
inputs = sc.transform(inputs)
X_test = []
for i in range(60, 76):
    X_test.append(inputs[i-60:i, 0])
X_test = np.array(X_test)
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
y_pred = regressor.predict(X_test)
predicted_stock_price = sc.inverse_transform(y_pred)

# 可視化
plt.plot(y_test, color = 'black', label = 'SZ000001 Price')
plt.plot(y_pred, color = 'green', label = 'Predicted SZ000001 Price')
plt.title('SZ000001 Price Prediction')
plt.xlabel('Time')
plt.ylabel('SZ000001 Price')
plt.legend()
plt.show()

輸出:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2026 CODEPRJ.COM