1. nn.LSTM
1.1 lstm=nn.LSTM(input_size, hidden_size, num_layers)
lstm=nn.LSTM(input_size, hidden_size, num_layers)
參數:
-
input_size
:輸入特征的維度, 一般rnn中輸入的是詞向量,那么 input_size 就等於一個詞向量的維度,即feature_len; -
hidden_size
:隱藏層神經元個數,或者也叫輸出的維度(因為rnn輸出為各個時間步上的隱藏狀態); -
num_layers
:網絡的層數;
1.2 out, (h_t, c_t) = lstm(x, [h_t0, c_t0])
-
x
:[seq_len, batch, feature_len] -
h/c
:[num_layer, batch, hidden_len] -
out
:[seq_len, batch, hidden_len]
import torch
from torch import nn
# 4層的LSTM,輸入的每個詞用100維向量表示,隱藏單元和記憶單元的尺寸是20
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
# 3句話,每句10個單詞,每個單詞的詞向量維度(長度)100
x = torch.rand(10, 3, 100)
# 不傳入h_0和c_0則會默認初始化
out, (h, c) = lstm(x)
print(out.shape) # torch.Size([10, 3, 20])
print(h.shape) # torch.Size([4, 3, 20])
print(c.shape) # torch.Size([4, 3, 20])
2. nn.LSTMCell
nn.LSTMCell
與nn.LSTM
的區別 和nn.RNN
與nn.RNNCell
的區別一樣。
2.1 nn.LSTMCell()
- 初始化方法和上面一樣。
2.2 h_t, c_t = lstmcell(x_t, [h_t-1, c_t-1])
-
\(x_t\):[batch, feature_len]表示t時刻的輸入
-
\(h_{t-1}, c_{t-1}\):[batch, hidden_len],\(t-1\)時刻本層的隱藏單元和記憶單元
多層LSTM類似下圖:

import torch
from torch import nn
# 單層LSTM
# 1層的LSTM,輸入的每個詞用100維向量表示,隱藏單元和記憶單元的尺寸是20
cell = nn.LSTMCell(input_size=100, hidden_size=20)
# seq_len=10個時刻的輸入,每個時刻shape都是[batch,feature_len]
# x = [torch.randn(3, 100) for _ in range(10)]
x = torch.randn(10, 3, 100)
# 初始化隱藏單元h和記憶單元c,取batch=3
h = torch.zeros(3, 20)
c = torch.zeros(3, 20)
# 對每個時刻,傳入輸入xt和上個時刻的h和c
for xt in x:
b, c = cell(xt, (h, c))
print(b.shape) # torch.Size([3, 20])
print(c.shape) # torch.Size([3, 20])
# 兩層LSTM
# 輸入的feature_len=100,變到該層隱藏單元和記憶單元hidden_len=30
cell_L0 = nn.LSTMCell(input_size=100, hidden_size=30)
# hidden_len從L0層的30變到這一層的20
cell_L1 = nn.LSTMCell(input_size=30, hidden_size=20)
# 分別初始化L0層和L1層的隱藏單元h 和 記憶單元C,取batch=3
h_L0 = torch.zeros(3, 30)
C_L0 = torch.zeros(3, 30)
h_L1 = torch.zeros(3, 20)
C_L1 = torch.zeros(3, 20)
x = torch.randn(10, 3, 100)
for xt in x:
h_L0, C_L0 = cell_L0(xt, (h_L0, C_L0)) # L0層接受xt輸入
h_L1, C_L1 = cell_L1(h_L0, (h_L1, C_L1)) # L1層接受L0層的輸出h作為輸入
print(h_L0.shape, C_L0.shape) # torch.Size([3, 30]) torch.Size([3, 30])
print(h_L1.shape, C_L1.shape) # torch.Size([3, 20]) torch.Size([3, 20])