[PyTorch] rnn,lstm,gru中輸入輸出維度


本文中的RNN泛指LSTM,GRU等等
CNN中和RNNbatchSize的默認位置是不同的。

  • CNN中:batchsize的位置是position 0.
  • RNN中:batchsize的位置是position 1.

在RNN中輸入數據格式:

對於最簡單的RNN,我們可以使用兩種方式來調用,torch.nn.RNNCell(),它只接受序列中的單步輸入,必須顯式的傳入隱藏狀態torch.nn.RNN()可以接受一個序列的輸入,默認會傳入一個全0的隱藏狀態,也可以自己申明隱藏狀態傳入。

  1. 輸入大小是三維tensor[seq_len,batch_size,input_dim]
  • input_dim是輸入的維度,比如是128
  • batch_size是一次往RNN輸入句子的數目,比如是5
  • seq_len是一個句子的最大長度,比如15
    所以千萬注意,RNN輸入的是序列,一次把批次的所有句子都輸入了,得到的ouptuthidden都是這個批次的所有的輸出和隱藏狀態,維度也是三維。
    **可以理解為現在一共有batch_size個獨立的RNN組件,RNN的輸入維度是input_dim,總共輸入seq_len個時間步,則每個時間步輸入到這個整個RNN模塊的維度是[batch_size,input_dim]
# 構造RNN網絡,x的維度5,隱層的維度10,網絡的層數2 rnn_seq = nn.RNN(5, 10,2) # 構造一個輸入序列,句長為 6,batch 是 3, 每個單詞使用長度是 5的向量表示 x = torch.randn(6, 3, 5) #out,ht = rnn_seq(x,h0) out,ht = rnn_seq(x) #h0可以指定或者不指定 

問題1:這里outhtsize是多少呢?
回答out:6 * 3 * 10, ht: 2 * 3 * 10,out的輸出維度[seq_len,batch_size,output_dim],ht的維度[num_layers * num_directions, batch, hidden_size],如果是單向單層的RNN那么一個句子只有一個hidden
問題2out[-1]ht[-1]是否相等?
回答:相等,隱藏單元就是輸出的最后一個單元,可以想象,每個的輸出其實就是那個時間步的隱藏單元

  1. RNN的其他參數
RNN(input_dim ,hidden_dim ,num_layers ,…)
– input_dim 表示輸入的特征維度
– hidden_dim 表示輸出的特征維度,如果沒有特殊變化,相當於out
– num_layers 表示網絡的層數
– nonlinearity 表示選用的非線性激活函數,默認是 ‘tanh’
– bias 表示是否使用偏置,默認使用
– batch_first 表示輸入數據的形式,默認是 False,就是這樣形式,(seq, batch, feature),也就是將序列長度放在第一位,batch 放在第二位 – dropout 表示是否在輸出層應用 dropout – bidirectional 表示是否使用雙向的 rnn,默認是 False 
 
向RNN中輸入的tensor的形狀

LSTM的輸出多了一個memory單元

# 輸入維度 50,隱層100維,兩層 lstm_seq = nn.LSTM(50, 100, num_layers=2) # 輸入序列seq= 10,batch =3,輸入維度=50 lstm_input = torch.randn(10, 3, 50) out, (h, c) = lstm_seq(lstm_input) # 使用默認的全 0 隱藏狀態 

問題1out(h,c)的size各是多少?
回答out:(10 * 3 * 100),(h,c):都是(2 * 3 * 100)
問題2out[-1,:,:]h[-1,:,:]相等嗎?
回答: 相等

GRU比較像傳統的RNN

gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim out, h = gru_seq(gru_input)


作者:VanJordan
鏈接:https://www.jianshu.com/p/b942e65cb0a3
來源:簡書
簡書著作權歸作者所有,任何形式的轉載都請聯系作者獲得授權並注明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM