pytorch之 RNN 參數解釋


上次通過pytorch實現了RNN模型,簡易的完成了使用RNN完成mnist的手寫數字識別,但是里面的參數有點不了解,所以對問題進行總結歸納來解決。

 

總述:
第一次看到這個函數時,腦袋有點懵,總結了下總共有五個問題:

1.這個input_size是啥?要輸入啥?feature num又是啥?

2.這個hidden_size是啥?要輸入啥?feature num又是啥?

3.不是說RNN會有很多個節點連在一起的嗎?這怎么定義連接的節點數呢?

4.num_layer中說的stack是怎么stack的?

5.怎么輸出會有兩個東西呀output,hn

pytorch中RNN的一些參數,並且解決以上五個問題

1.Pytorch中的RNN

 

 

 


2.input_size是啥?
說白了input_size無非就是你輸入RNN的維度,比如說NLP中你需要把一個單詞輸入到RNN中,這個單詞的編碼是300維的,那么這個input_size就是300.這里的input_size其實就是規定了你的輸入變量的維度。用f(wX+b)來類比的話,這里輸入的就是X的維度。

3.hidden_size是啥?
和最簡單的BP網絡一樣的,每個RNN的節點實際上就是一個BP嘛,包含輸入層,隱含層,輸出層。這里的hidden_size呢,你可以看做是隱含層中,隱含節點的個數。

 

 

 

那個輸入層的三個節點代表輸入維度為3,也就是input_size=3,然后這個hidden_size就是5了。當然這是是對於RNN某一個節點而言的,那么如何規定RNN的節點個數呢?

4.如何規定節點個數?

事實上,節點個數並不需要規定,你的輸入序列是這樣子的,[x1,x2,x3,x4,x5],那么input_size呢就是你的xi的維度,而你的RNN的節點數呢,就是由你的序列長度決定的,在這里我們的序列長度是5,所以會有5個節點。那么問題來了,我咋知道你的序列長度呢?pytorch里面不是只有input_size的參數嗎?實際上,你聲明RNN是這樣聲明的

self.encoder = nn.RNN(input_size=300,hidden_size=128,dropout=0.5)
但是你用的時候;

output,hn = self.encoder(encoder_input,encoder_hidden)
你會把你的數據丟進去吧,也就是你把encoder_input這一整個序列丟進去了,那么序列長度他不就知道了?

5.num_layers是啥?
一開始你是不是以為這個就是RNN的節點數呀,hhh,然而並不是:),如果num_layer=2的話,表示兩個RNN堆疊在一起。那么怎么堆疊的呢?

如果是num_layer==1的話:

 

 

 

如果num_layer==2的話:

 

 

 

ok了~最后再來看看最后一個問題

6.hn,output分別是啥?

  hidden的輸出size為[ num_layers* num_directions, batch_size, n_hidden].

  說白了,hidden就是每個方向,每個層的 隱藏單元的輸出,所以是n_hidden個。

  output的size(如果RNN設定的batch_first=True),那么就是[batch_size,seq_len,n_hidden],對於分類任務如果要取得最后一個output,只需添加下標  [ :,-1,:]

看圖找答案:

 

 

 

hn就是RNN的最后一個隱含狀態,output就是RNN最終得到的結果。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM