pytorch nn.LSTM()參數詳解


輸入數據格式:
input(seq_len, batch, input_size)
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)

輸出數據格式:
output(seq_len, batch, hidden_size * num_directions)
hn(num_layers * num_directions, batch, hidden_size)
cn(num_layers * num_directions, batch, hidden_size)

import torch
import torch.nn as nn
from torch.autograd import Variable

#構建網絡模型---輸入矩陣特征數input_size、輸出矩陣特征數hidden_size、層數num_layers
inputs = torch.randn(5,3,10) ->(seq_len,batch_size,input_size)
rnn = nn.LSTM(10,20,2) -> (input_size,hidden_size,num_layers)
h0 = torch.randn(2,3,20) ->(num_layers* 1,batch_size,hidden_size)
c0 = torch.randn(2,3,20) ->(num_layers*1,batch_size,hidden_size)
num_directions=1 因為是單向LSTM
'''
Outputs: output, (h_n, c_n)
'''
output,(hn,cn) = rnn(inputs,(h0,c0))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
batch_first: 輸入輸出的第一維是否為 batch_size,默認值 False。因為 Torch 中,人們習慣使用Torch中帶有的dataset,dataloader向神經網絡模型連續輸入數據,這里面就有一個 batch_size 的參數,表示一次輸入多少個數據。 在 LSTM 模型中,輸入數據必須是一批數據,為了區分LSTM中的批量數據和dataloader中的批量數據是否相同意義,LSTM 模型就通過這個參數的設定來區分。 如果是相同意義的,就設置為True,如果不同意義的,設置為False。 torch.LSTM 中 batch_size 維度默認是放在第二維度,故此參數設置可以將 batch_size 放在第一維度。如:input 默認是(4,1,5),中間的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的輸入數據是二維數據的話,就應該將 batch_first 設置為True;

inputs = torch.randn(5,3,10) :seq_len=5,bitch_size=3,input_size=10
我的理解:有3個句子,每個句子5個單詞,每個單詞用10維的向量表示;而句子的長度是不一樣的,所以seq_len可長可短,這也是LSTM可以解決長短序列的特殊之處。只有seq_len這一參數是可變的。
關於hn和cn一些參數的詳解看這里
而在遇到文本長度不一致的情況下,將數據輸入到模型前的特征工程會將同一個batch內的文本進行padding使其長度對齊。但是對齊的數據在單向LSTM甚至雙向LSTM的時候有一個問題,LSTM會處理很多無意義的填充字符,這樣會對模型有一定的偏差,這時候就需要用到函數torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()
詳情解釋看這里

BiLSTM
BILSTM是雙向LSTM;將前向的LSTM與后向的LSTM結合成LSTM。視圖舉例如下:


​​​​​​​​​​​​LSTM結構推導:


更詳細公式推導https://blog.csdn.net/songhk0209/article/details/71134698

GRU公式推導:(網上的圖看着有點費勁,就自己畫了個數據流圖)


---------------------
作者:向陽爭渡
來源:CSDN
原文:https://blog.csdn.net/yangyang_yangqi/article/details/84585998
版權聲明:本文為博主原創文章,轉載請附上博文鏈接!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM