Keras中RNN、LSTM和GRU的參數計算


1. RNN

 

 

 

RNN結構圖

計算公式:

 

 

 

代碼:

1 model = Sequential()
2 model.add(SimpleRNN(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

運行結果:

 

 

 可見,共70個參數

記輸入維度(x的維度,本例中為2)為dx, 輸出維度(h的維度, 與隱藏單元數目一致,本例中為7)為dh

則公式中U的shape應該是dh*dx, W的shape因該是dh*dh, b的shape應該是dh*1

這樣計算的h(t)維度才能是dh

計算公式:

nums = dh * ( dh + dx ) + dh

括號中可以理解為x和h(t-1)合並

70 = 7 *( 7 + 2 ) + 7

 

2. LSTM

https://zhuanlan.zhihu.com/p/147496732

參考這篇吧,講的不錯

 

LSTM單元結構圖

 

 

 

代碼:

1 model = Sequential()
2 model.add(LSTM(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

運行結果:

 

 

 計算公式:

nums = 4 * [ dh * (dh + dx) + dh ]

 280 = 4 * [ 7 * (7 + 2) + 7 ]

 

3. GRU

 

GRU單元結構圖

代碼:

1 model = Sequential()
2 model.add(GRU(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

運行結果:

 

 計算方式:

nums = 3 * [ dh * (dh + dx) + dh ]

 210 = 3 * [ 7 * (7 + 2) + 7 ]

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM