1. RNN
RNN結構圖
計算公式:
代碼:
1 model = Sequential() 2 model.add(SimpleRNN(7, batch_input_shape=(None, 4, 2))) 3 model.summary()
運行結果:
可見,共70個參數
記輸入維度(x的維度,本例中為2)為dx, 輸出維度(h的維度, 與隱藏單元數目一致,本例中為7)為dh
則公式中U的shape應該是dh*dx, W的shape因該是dh*dh, b的shape應該是dh*1
這樣計算的h(t)維度才能是dh
計算公式:
nums = dh * ( dh + dx ) + dh
括號中可以理解為x和h(t-1)合並
70 = 7 *( 7 + 2 ) + 7
2. LSTM
https://zhuanlan.zhihu.com/p/147496732
參考這篇吧,講的不錯
LSTM單元結構圖
代碼:
1 model = Sequential() 2 model.add(LSTM(7, batch_input_shape=(None, 4, 2))) 3 model.summary()
運行結果:
計算公式:
nums = 4 * [ dh * (dh + dx) + dh ]
280 = 4 * [ 7 * (7 + 2) + 7 ]
3. GRU
GRU單元結構圖
代碼:
1 model = Sequential() 2 model.add(GRU(7, batch_input_shape=(None, 4, 2))) 3 model.summary()
運行結果:
計算方式:
nums = 3 * [ dh * (dh + dx) + dh ]
210 = 3 * [ 7 * (7 + 2) + 7 ]