Keras中RNN、LSTM和GRU的参数计算


1. RNN

 

 

 

RNN结构图

计算公式:

 

 

 

代码:

1 model = Sequential()
2 model.add(SimpleRNN(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

运行结果:

 

 

 可见,共70个参数

记输入维度(x的维度,本例中为2)为dx, 输出维度(h的维度, 与隐藏单元数目一致,本例中为7)为dh

则公式中U的shape应该是dh*dx, W的shape因该是dh*dh, b的shape应该是dh*1

这样计算的h(t)维度才能是dh

计算公式:

nums = dh * ( dh + dx ) + dh

括号中可以理解为x和h(t-1)合并

70 = 7 *( 7 + 2 ) + 7

 

2. LSTM

https://zhuanlan.zhihu.com/p/147496732

参考这篇吧,讲的不错

 

LSTM单元结构图

 

 

 

代码:

1 model = Sequential()
2 model.add(LSTM(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

运行结果:

 

 

 计算公式:

nums = 4 * [ dh * (dh + dx) + dh ]

 280 = 4 * [ 7 * (7 + 2) + 7 ]

 

3. GRU

 

GRU单元结构图

代码:

1 model = Sequential()
2 model.add(GRU(7, batch_input_shape=(None, 4, 2)))
3 model.summary()

运行结果:

 

 计算方式:

nums = 3 * [ dh * (dh + dx) + dh ]

 210 = 3 * [ 7 * (7 + 2) + 7 ]

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM