1. nn.LSTM
1.1 lstm = nn.LSTM(input_size, hidden_size, num_layers)
Parameters:
- input_size: dimensionality of each input feature vector. When an RNN is fed word vectors, input_size equals the word-vector dimension, i.e. feature_len;
- hidden_size: number of hidden units per layer, which is also the output dimension (the RNN outputs its hidden state at each time step);
- num_layers: number of stacked layers.
1.2 out, (ht, ct) = lstm(x, (h0, c0))
- x: [seq_len, batch, feature_len], the input sequence;
- h / c: [num_layers, batch, hidden_size], the hidden and cell states;
- out: [seq_len, batch, hidden_size], the last layer's hidden state at every time step.
import torch
from torch import nn

lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)  # 4-layer LSTM; each word is a 100-dim vector; hidden and cell states have size 20

x = torch.randn(10, 3, 100)  # 3 sentences, 10 words each, each word a 100-dim vector
out, (h, c) = lstm(x)  # if h_0 and c_0 are not passed, they default to zeros
print(out.shape)  # torch.Size([10, 3, 20])
print(h.shape)    # torch.Size([4, 3, 20])
print(c.shape)    # torch.Size([4, 3, 20])
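You can also pass the initial states explicitly instead of relying on the default. The following sketch (using explicit zero states, which match the default) also checks a useful shape fact: the final time step of out equals the last layer's entry in h.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)

x = torch.randn(10, 3, 100)
# explicit initial states, shape [num_layers, batch, hidden_size]
h0 = torch.zeros(4, 3, 20)
c0 = torch.zeros(4, 3, 20)
out, (h, c) = lstm(x, (h0, c0))

# out is the last layer's hidden state at every time step,
# so its final step matches the last layer's row of h
print(torch.allclose(out[-1], h[-1]))  # True
```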
2. nn.LSTMCell
nn.LSTMCell relates to nn.LSTM the same way nn.RNNCell relates to nn.RNN: a cell computes a single time step, so you unroll the loop over the sequence yourself.
2.1 nn.LSTMCell()
Construction takes the same input_size and hidden_size as above. Since a cell handles one layer at one time step, there is no num_layers argument; for a multi-layer model you stack cells manually.
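A quick sketch of the cell's parameters: the four gates (input, forget, cell, output) are stacked along the first axis, so the weight matrices have 4 * hidden_size rows.

```python
import torch
from torch import nn

cell = nn.LSTMCell(input_size=100, hidden_size=20)
# weight_ih maps the input, weight_hh maps the previous hidden state;
# both stack the four gates, hence 4 * 20 = 80 rows
print(cell.weight_ih.shape)  # torch.Size([80, 100])
print(cell.weight_hh.shape)  # torch.Size([80, 20])
```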
2.2 ht, ct = lstmcell(xt, (ht-1, ct-1))
- xt: [batch, feature_len], the input at time step t;
- ht-1, ct-1: [batch, hidden_size], this layer's hidden and cell states from time step t-1.
# Single-layer LSTM
import torch
from torch import nn

cell = nn.LSTMCell(input_size=100, hidden_size=20)  # one LSTM layer; each word is a 100-dim vector; hidden and cell states have size 20

h = torch.zeros(3, 20)  # initialize hidden state h and cell state c, with batch=3
c = torch.zeros(3, 20)

x = [torch.randn(3, 100) for _ in range(10)]  # seq_len=10 time steps; each input has shape [batch, feature_len]

for xt in x:  # at each time step, feed the input xt along with the previous h and c
    h, c = cell(xt, (h, c))

print(h.shape, c.shape)  # torch.Size([3, 20]) torch.Size([3, 20])


# Two-layer LSTM
cell_l0 = nn.LSTMCell(input_size=100, hidden_size=30)  # maps feature_len=100 to this layer's hidden size 30
cell_l1 = nn.LSTMCell(input_size=30, hidden_size=20)   # hidden size goes from 30 (layer 0) to 20 (layer 1)

h_l0 = torch.zeros(3, 30)  # initialize hidden and cell states for both layers, batch=3
c_l0 = torch.zeros(3, 30)

h_l1 = torch.zeros(3, 20)
c_l1 = torch.zeros(3, 20)

x = [torch.randn(3, 100) for _ in range(10)]  # seq_len=10 time steps; each input has shape [batch, feature_len]

for xt in x:
    h_l0, c_l0 = cell_l0(xt, (h_l0, c_l0))    # layer 0 consumes the input xt
    h_l1, c_l1 = cell_l1(h_l0, (h_l1, c_l1))  # layer 1 consumes layer 0's hidden state

print(h_l0.shape, c_l0.shape)  # torch.Size([3, 30]) torch.Size([3, 30])
print(h_l1.shape, c_l1.shape)  # torch.Size([3, 20]) torch.Size([3, 20])
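To confirm that a manually unrolled nn.LSTMCell really computes the same thing as a single-layer nn.LSTM, you can copy the LSTM's layer-0 parameters (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0) into a cell and compare outputs. A sketch:

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=1)
cell = nn.LSTMCell(input_size=100, hidden_size=20)

# copy the LSTM's layer-0 parameters into the cell
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(10, 3, 100)  # [seq_len, batch, feature_len]
out, (h, c) = lstm(x)

# unroll the cell by hand over the same input
ht = torch.zeros(3, 20)
ct = torch.zeros(3, 20)
outs = []
for xt in x:
    ht, ct = cell(xt, (ht, ct))
    outs.append(ht)

# the per-step hidden states match the LSTM's out up to numerical precision
print(torch.allclose(out, torch.stack(outs), atol=1e-6))  # True
```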