Background:
RNN (Recurrent Neural Network) is usually translated into Chinese as 循環神經網絡 ("circulating neural network") or 遞歸神經網絡 ("recursive neural network"). In my humble opinion, neither rendering is apt; a better name would be: (deep) shared-parameter temporal neural network (elaborated below).
The RNN formula (from the PyTorch nn.RNN documentation):
\begin{align} h_t &= \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh}) \tag{1} \end{align}
This formula captures, for each RNN layer, the computation of the input term (the "input gates") and of the hidden state; it is the update rule applied by every RNN layer, where
$W_{ih}$ and $W_{hh}$ are that layer's weight parameters for the input-gate term and the hidden-state term respectively,
$b_{ih}$ and $b_{hh}$ are the corresponding biases of that layer,
$h_{(t-1)}$ is the hidden state from the previous timestep (for the first element of a sequence, $h_{(t-1)}$, i.e. $h_0$, can be randomly initialized).
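To make formula (1) concrete, here is a minimal one-timestep sketch in PyTorch. The sizes (input_size = 10, hidden_size = 20, batch_size = 3) are illustrative choices, not anything mandated by the formula; the weight layout follows PyTorch's convention of storing weight_ih as (hidden_size, input_size).

import torch

input_size, hidden_size, batch_size = 10, 20, 3  # assumed sizes for illustration

W_ih = torch.randn(hidden_size, input_size)   # input-to-hidden weights
W_hh = torch.randn(hidden_size, hidden_size)  # hidden-to-hidden weights
b_ih = torch.randn(hidden_size)
b_hh = torch.randn(hidden_size)

x_t = torch.randn(batch_size, input_size)      # the sample at timestep t
h_prev = torch.randn(batch_size, hidden_size)  # h_{t-1}; h_0 may be random

# formula (1): h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
h_t = torch.tanh(x_t @ W_ih.T + b_ih + h_prev @ W_hh.T + b_hh)
print(h_t.shape)  # torch.Size([3, 20])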
Two things to note:
1. $x_t$: for the first layer, $x_t$ is naturally the training sample. But what is $x_t$ for the second layer? Still the training sample? Keep this question in mind as you read on.
2. Every RNN layer computes formula (1), and every layer has exactly these parameters: $W_{ih}$, $W_{hh}$, $b_{ih}$, $b_{hh}$, i.e. four parameter variables per layer (no matter the sequence length, and no matter the dimension of each sample). $x_t$ and $h_{(t-1)}$ are intermediate results, not parameters. So a one-layer RNN has 4 parameter variables, a two-layer RNN has 8, and an N-layer RNN has 4N. Each parameter variable is of course a tensor (the W's are matrices, the b's vectors). Once you know how many parameters a model has, its structure is mostly determined, which is why this point is so important.
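The 4N count is easy to verify against PyTorch itself. A quick sanity check (the layer sizes here are arbitrary):

import torch
from torch import nn

for num_layers in (1, 2, 3):
    rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=num_layers)
    names = [name for name, _ in rnn.named_parameters()]
    # 4 parameter tensors per layer: weight_ih_lk, weight_hh_lk, bias_ih_lk, bias_hh_lk
    print(num_layers, len(names), names)  # len(names) prints 4, 8, 12 respectively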
Model structure:
There are many diagrams online that explain the RNN model structure.
There is this kind (Figure 1):
[figure omitted]
This kind (Figure 2):
[figure omitted]
This kind (Figure 3):
[figure omitted]
And also this kind (Figure 4):
[figure omitted]
In my humble opinion, the easiest one to understand is Figure 5 below, from Scofield on Zhihu:
[figure omitted]
The animation below shows the computation of a multi-layer RNN over samples of dimension 3. In my humble opinion, this animation is the most instructive of all. Thanks to 劉大力 on Zhihu; the animation comes from his Zhihu article (if it does not play here, please follow that link).
The figures and the animation above illustrate the computation graph as the sample at each timestep t of a sequence is fed into the RNN model, but I do not endorse how they depict the model output Y.
A training sample spans t timesteps, and the sample at each timestep has multiple dimensions. In the animation, t = 5 and dimension = 3.
In my view, a one-layer RNN model has exactly four parameter variables, $W_{ih}$, $W_{hh}$, $b_{ih}$, $b_{hh}$, shared across timesteps (the same parameter variables; their values can change during training). The elements of the sequence are combined with these four parameters in temporal order, and the hidden state $h_{(t-1)}$ comes from the previous timestep's result, i.e. there is a temporal dependency. Hence a one-layer RNN is better called a shared-parameter temporal neural network, and a multi-layer RNN a deep shared-parameter temporal neural network. These names may not spread easily, but they are easy to understand.
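The temporal dependency is easy to observe empirically. A minimal sketch, assuming a freshly initialized nn.RNN: feeding the same timesteps in reversed order generically produces a different final hidden state, even though exactly the same four parameters are used both times.

import torch
from torch import nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1)
x = torch.randn(5, 3, 10)                # (seq_len, batch, dim)

_, h_fwd = rnn(x)
_, h_rev = rnn(torch.flip(x, dims=[0]))  # same timesteps, reversed order

# same four parameters either way, but the order of the sequence matters
print(torch.allclose(h_fwd, h_rev))      # False (generically)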
In a two-layer RNN, the sample at the first timestep is first combined with the first layer's 4 parameters via formula (1); the resulting hidden state $h_{(t)_{l0}}$ is stored and also passed on as the $x_t$ of the second layer's formula (1), whose resulting hidden state $h_{(t)_{l1}}$ is stored as well. The two stored values $h_{(t)_{l0}}$ and $h_{(t)_{l1}}$ then take part in the computation at the second timestep (becoming $h_{(t-1)_{l0}}$ and $h_{(t-1)_{l1}}$), i.e. they serve as the $h_{(t-1)}$ of the first and second layer's formula (1) respectively. Multi-layer RNNs continue this pattern. So an RNN with L layers carries L hidden states $h_{(t)_{lx}}$, which is exactly what the shape of h_n in the PyTorch RNN documentation reflects (see the sketch below). The key point: each $h_{(t)_{l}}$ is used in two places, as the next layer's $x_t$ and as the next timestep's $h_{(t-1)_{l}}$. Many RNN structure diagrams online fail to make this clear.
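The h_n shape is easy to confirm; a quick sketch with arbitrary sizes (here num_layers = 2, and h0 defaults to zeros when omitted):

import torch
from torch import nn

seq_len, batch_size, input_size, hidden_size, num_layers = 5, 3, 10, 20, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)
output, h_n = rnn(torch.randn(seq_len, batch_size, input_size))

print(output.shape)  # torch.Size([5, 3, 20]): the last layer's h_t at every timestep
print(h_n.shape)     # torch.Size([2, 3, 20]): one final h_t per layer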
With all this in mind, look again at structure Figure 7 below (same parameters at every timestep; every hidden state goes to two places):
[figure omitted]
Does it read a little more clearly now?
Code check: a hand-written RNN compared against the official PyTorch RNN (one-layer RNN):
import torch
from torch import nn

# network parameters
input_size = 10
hidden_size = 20
num_layers = 1  # fixed, can't change: only one layer for this demo

# data parameters
seq_len = 5
batch_size = 3
data_dim = input_size

# input data
data = torch.randn(seq_len, batch_size, data_dim)

# official rnn in pytorch
ornn = nn.RNN(input_size, hidden_size, num_layers)

# init hidden state
h0 = torch.randn(num_layers, batch_size, hidden_size)

# rnn implemented by myself
class MyRNN():
    def __init__(self):
        # keep the weight and bias parameters identical to the official rnn,
        # so the final results can be compared against it
        self.W_ih = torch.nn.Parameter(ornn.weight_ih_l0.T)
        self.b_ih = torch.nn.Parameter(ornn.bias_ih_l0)
        self.W_hh = torch.nn.Parameter(ornn.weight_hh_l0.T)
        self.b_hh = torch.nn.Parameter(ornn.bias_hh_l0)
        self.ht = torch.nn.Parameter(h0)
        self.myoutput = []

    def forward(self, x):
        # x shape: (seq_len, batch_size, data_dim)
        for i in range(seq_len):  # this loop is the KEY to understanding RNN. Important!
            igates = torch.matmul(x[i], self.W_ih) + self.b_ih
            hgates = torch.matmul(self.ht, self.W_hh) + self.b_hh
            self.ht = torch.tanh(igates + hgates)  # this line is formula (1) of the RNN. Important!
            self.myoutput.append(self.ht)
        return self.ht, self.myoutput

myrnn = MyRNN()
myht, myoutput = myrnn.forward(data)
official_output, official_hn = ornn(data, h0)

print('myht:')
print(myht)
print('official_hn:')
print(official_hn)
print("--" * 40)
print('myoutput:')
print(myoutput)
print('official_output:')
print(official_output)
Output:
myht: tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713, 0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690, -0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505], [-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828, -0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967, -0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153], [ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230, 0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637, -0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]], grad_fn=<TanhBackward>) official_hn: tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713, 0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690, -0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505], [-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828, -0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967, -0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153], [ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230, 0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637, -0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]], grad_fn=<StackBackward>) -------------------------------------------------------------------------------- myoutput: [tensor([[[ 0.1838, -0.5729, 0.7425, -0.1386, 0.4525, -0.0928, 0.4676, 0.1947, -0.2111, -0.2790, -0.3584, 0.1215, -0.5577, 0.3709, 0.9216, 0.0695, 0.0420, -0.5991, -0.8501, 0.4155], [-0.0024, -0.5132, -0.6784, 0.7312, -0.1101, -0.4194, 0.1185, 0.4437, -0.5395, 0.8785, -0.6332, -0.5439, -0.4265, 0.1511, -0.0327, 0.4625, -0.4097, -0.9240, -0.6085, 0.3099], [ 0.1994, 0.6158, 0.9422, 0.8493, -0.6427, 0.0086, 0.0350, 0.1801, -0.8858, 0.4427, -0.2625, 0.7059, -0.4321, 0.5412, 0.5879, 0.5385, -0.2290, -0.8183, -0.4205, -0.7687]]], grad_fn=<TanhBackward>), tensor([[[ 0.3837, -0.0271, 0.1710, 0.5887, -0.1873, -0.0959, 0.3320, 0.0613, 0.3565, -0.7313, -0.2641, -0.8821, 0.7630, 0.2369, 0.5095, -0.7738, 0.0350, 0.1001, 0.4966, 0.4144], [-0.1493, -0.3873, 0.6141, 0.1870, 0.0825, -0.0518, 0.0583, 0.3066, 0.6362, 0.1345, -0.2821, 0.0061, -0.3376, -0.2284, 0.1351, 0.3951, 0.0039, -0.6607, -0.1473, 0.6156], [ 0.8971, -0.1361, 0.0733, 0.5407, -0.5882, -0.4531, 0.2926, 0.5090, 0.4893, -0.2589, 0.1735, -0.1201, -0.0110, -0.4264, 0.3931, 0.0637, 0.5885, 0.4706, 0.1418, 0.3165]]], grad_fn=<TanhBackward>), tensor([[[ 0.3517, -0.7295, -0.0883, -0.6818, 0.3883, 0.3556, -0.1627, -0.1085, 0.6256, 0.8205, -0.6915, 0.5160, -0.0390, 0.3519, -0.0271, 0.0300, 0.0965, -0.3939, -0.0956, 0.2624], [ 0.5152, 0.0578, 0.4200, -0.6778, -0.6455, -0.1427, -0.2189, 0.1818, -0.1449, 0.1035, -0.6252, 0.7734, -0.5083, 0.6138, 0.4270, 0.5684, 0.6656, 0.5341, -0.0336, 0.6554], [-0.2308, 0.4569, 0.2901, -0.1212, -0.4826, -0.2699, 0.2559, -0.3331, -0.0299, 0.0830, 0.2832, -0.5203, -0.0953, -0.3784, -0.1478, -0.1610, -0.3416, -0.7735, 0.4389, 0.4663]]], grad_fn=<TanhBackward>), tensor([[[-0.4632, -0.7146, -0.0497, -0.4927, 0.0778, 0.6394, 0.0383, -0.6022, 0.4774, -0.0682, -0.1731, -0.5328, -0.2757, 0.1885, 0.6235, -0.0990, -0.3720, 0.0275, 0.4964, 0.7343], [ 0.7086, -0.7316, -0.7619, 0.4543, -0.0888, 0.5574, 0.1033, 0.4042, 0.4909, 0.2489, -0.6275, -0.9121, 0.4050, 0.5086, 0.1161, -0.3312, -0.0297, 0.0204, 0.0442, 0.5536], [-0.0543, 0.0078, -0.8657, 0.6617, -0.2335, -0.0423, 0.2600, 0.1319, -0.2510, 0.3286, 0.1862, -0.6161, -0.1817, 0.4460, -0.6628, -0.1969, 0.5526, -0.8781, -0.4859, 0.6430]]], grad_fn=<TanhBackward>), tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713, 0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690, 
-0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505], [-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828, -0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967, -0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153], [ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230, 0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637, -0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]], grad_fn=<TanhBackward>)] official_output: tensor([[[ 0.1838, -0.5729, 0.7425, -0.1386, 0.4525, -0.0928, 0.4676, 0.1947, -0.2111, -0.2790, -0.3584, 0.1215, -0.5577, 0.3709, 0.9216, 0.0695, 0.0420, -0.5991, -0.8501, 0.4155], [-0.0024, -0.5132, -0.6784, 0.7312, -0.1101, -0.4194, 0.1185, 0.4437, -0.5395, 0.8785, -0.6332, -0.5439, -0.4265, 0.1511, -0.0327, 0.4625, -0.4097, -0.9240, -0.6085, 0.3099], [ 0.1994, 0.6158, 0.9422, 0.8493, -0.6427, 0.0086, 0.0350, 0.1801, -0.8858, 0.4427, -0.2625, 0.7059, -0.4321, 0.5412, 0.5879, 0.5385, -0.2290, -0.8183, -0.4205, -0.7687]], [[ 0.3837, -0.0271, 0.1710, 0.5887, -0.1873, -0.0959, 0.3320, 0.0613, 0.3565, -0.7313, -0.2641, -0.8821, 0.7630, 0.2369, 0.5095, -0.7738, 0.0350, 0.1001, 0.4966, 0.4144], [-0.1493, -0.3873, 0.6141, 0.1870, 0.0825, -0.0518, 0.0583, 0.3066, 0.6362, 0.1345, -0.2821, 0.0061, -0.3376, -0.2284, 0.1351, 0.3951, 0.0039, -0.6607, -0.1473, 0.6156], [ 0.8971, -0.1361, 0.0733, 0.5407, -0.5882, -0.4531, 0.2926, 0.5090, 0.4893, -0.2589, 0.1735, -0.1201, -0.0110, -0.4264, 0.3931, 0.0637, 0.5885, 0.4706, 0.1418, 0.3165]], [[ 0.3517, -0.7295, -0.0883, -0.6818, 0.3883, 0.3556, -0.1627, -0.1085, 0.6256, 0.8205, -0.6915, 0.5160, -0.0390, 0.3519, -0.0271, 0.0300, 0.0965, -0.3939, -0.0956, 0.2624], [ 0.5152, 0.0578, 0.4200, -0.6778, -0.6455, -0.1427, -0.2189, 0.1818, -0.1449, 0.1035, -0.6252, 0.7734, -0.5083, 0.6138, 0.4270, 0.5684, 0.6656, 0.5341, -0.0336, 0.6554], [-0.2308, 0.4569, 0.2901, -0.1212, -0.4826, -0.2699, 0.2559, -0.3331, -0.0299, 0.0830, 0.2832, -0.5203, -0.0953, -0.3784, -0.1478, -0.1610, -0.3416, -0.7735, 0.4389, 0.4663]], [[-0.4632, -0.7146, -0.0497, -0.4927, 0.0778, 0.6394, 0.0383, -0.6022, 0.4774, -0.0682, -0.1731, -0.5328, -0.2757, 0.1885, 0.6235, -0.0990, -0.3720, 0.0275, 0.4964, 0.7343], [ 0.7086, -0.7316, -0.7619, 0.4543, -0.0888, 0.5574, 0.1033, 0.4042, 0.4909, 0.2489, -0.6275, -0.9121, 0.4050, 0.5086, 0.1161, -0.3312, -0.0297, 0.0204, 0.0442, 0.5536], [-0.0543, 0.0078, -0.8657, 0.6617, -0.2335, -0.0423, 0.2600, 0.1319, -0.2510, 0.3286, 0.1862, -0.6161, -0.1817, 0.4460, -0.6628, -0.1969, 0.5526, -0.8781, -0.4859, 0.6430]], [[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713, 0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690, -0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505], [-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828, -0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967, -0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153], [ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230, 0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637, -0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]], grad_fn=<StackBackward>)
The two outputs are identical, which confirms that the model structure is correct.
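Rather than eyeballing the printed tensors, the comparison can be automated. A small sketch reusing the variables from the one-layer script above (the atol absorbs the tiny floating-point differences between the two implementations that are visible in the printed output):

import torch

# myht/myoutput come from MyRNN, official_hn/official_output from nn.RNN above
assert torch.allclose(myht, official_hn, atol=1e-6)
assert torch.allclose(torch.stack(myoutput).squeeze(1), official_output, atol=1e-6)
print('hand-written RNN matches nn.RNN')

The same check works for the two-layer demo below, except that there myht is a Python list of the two per-layer hidden states, so stack it first: torch.allclose(torch.stack(myht), official_hn, atol=1e-6).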
Code check: a hand-written RNN compared against the official PyTorch RNN (two-layer RNN):
import torch
from torch import nn

# network parameters
input_size = 10
hidden_size = 20
num_layers = 2

# data parameters
seq_len = 5
batch_size = 3
data_dim = input_size
data = torch.randn(seq_len, batch_size, data_dim)

# original official rnn in pytorch
ornn = nn.RNN(input_size, hidden_size, num_layers)
h0 = torch.randn(num_layers, batch_size, hidden_size)

class MyRNN():
    def __init__(self):
        # layer 0 parameters (input_size -> hidden_size), copied from the official rnn
        self.W_ih = torch.nn.Parameter(ornn.weight_ih_l0.T)
        self.b_ih = torch.nn.Parameter(ornn.bias_ih_l0)
        self.W_hh = torch.nn.Parameter(ornn.weight_hh_l0.T)
        self.b_hh = torch.nn.Parameter(ornn.bias_hh_l0)
        self.ht = torch.nn.Parameter(h0)
        self.myoutput = []
        if num_layers == 2:
            self.ht = torch.nn.Parameter(h0[0])
            self.ht1 = torch.nn.Parameter(h0[1])
            # layer 1 parameters, copied from the official rnn
            self.W_ih_l1 = torch.nn.Parameter(ornn.weight_ih_l1.T)
            self.b_ih_l1 = torch.nn.Parameter(ornn.bias_ih_l1)
            self.W_hh_l1 = torch.nn.Parameter(ornn.weight_hh_l1.T)
            self.b_hh_l1 = torch.nn.Parameter(ornn.bias_hh_l1)

    def forward(self, x):
        # x: (seq_len, batch_size, data_dim)
        for i in range(seq_len):
            # the first layer: apply the formula
            igates = torch.matmul(x[i], self.W_ih) + self.b_ih
            hgates = torch.matmul(self.ht, self.W_hh) + self.b_hh  # ht read from the previous timestep
            self.ht = torch.tanh(igates + hgates)  # ht update
            if num_layers == 2:
                # the second layer: apply the formula
                igates = torch.matmul(self.ht, self.W_ih_l1) + self.b_ih_l1  # ht read from the first layer. Important!
                hgates = torch.matmul(self.ht1, self.W_hh_l1) + self.b_hh_l1  # ht1 read from the previous timestep
                self.ht1 = torch.tanh(igates + hgates)  # ht1 update
                ht_final_layer = [self.ht, self.ht1]
                self.myoutput.append(self.ht1)  # important: just ht1, the output of the last layer
        return ht_final_layer, self.myoutput

myrnn = MyRNN()
myht, myoutput = myrnn.forward(data)
official_output, official_hn = ornn(data, h0)

print('myht:')
print(myht)
print('official_hn:')
print(official_hn)
print("--" * 40)
print('myoutput:')
print(myoutput)
print('official_output')
print(official_output)
Run result:
myht: [tensor([[-0.0386, 0.0588, 0.3025, -0.6304, 0.2505, -0.2632, 0.0101, -0.6417, 0.2560, -0.1788, 0.3951, -0.3890, 0.5895, 0.1630, 0.1462, -0.6854, -0.1802, -0.3126, -0.8059, -0.1910], [-0.3681, 0.2041, 0.2560, 0.6034, 0.1888, 0.0478, 0.4822, 0.0652, -0.7043, -0.2169, 0.2462, 0.1334, -0.1881, 0.4579, -0.0285, 0.1425, 0.3664, 0.4980, 0.2442, -0.5373], [-0.5242, -0.0747, -0.4040, -0.0835, 0.6314, 0.1566, 0.2049, -0.1784, -0.2990, -0.3908, -0.2911, -0.2110, 0.6358, 0.4597, 0.0701, -0.3386, 0.5218, -0.5246, -0.3237, 0.0551]], grad_fn=<TanhBackward>), tensor([[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01, 3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01, -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01, 5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2619e-02, -1.0607e-01], [-5.0548e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01, -8.8198e-02, 3.6017e-01, -7.5253e-02, -1.5342e-01, -1.3452e-01, -1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02, -1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01], [-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01, -1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01, -1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02, 1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]], grad_fn=<TanhBackward>)] official_hn: tensor([[[-3.8583e-02, 5.8803e-02, 3.0251e-01, -6.3039e-01, 2.5051e-01, -2.6322e-01, 1.0055e-02, -6.4175e-01, 2.5604e-01, -1.7878e-01, 3.9513e-01, -3.8902e-01, 5.8954e-01, 1.6296e-01, 1.4621e-01, -6.8542e-01, -1.8024e-01, -3.1264e-01, -8.0587e-01, -1.9096e-01], [-3.6807e-01, 2.0406e-01, 2.5604e-01, 6.0344e-01, 1.8878e-01, 4.7830e-02, 4.8223e-01, 6.5184e-02, -7.0430e-01, -2.1692e-01, 2.4618e-01, 1.3339e-01, -1.8806e-01, 4.5792e-01, -2.8516e-02, 1.4252e-01, 3.6637e-01, 4.9800e-01, 2.4424e-01, -5.3730e-01], [-5.2422e-01, -7.4715e-02, -4.0400e-01, -8.3507e-02, 6.3144e-01, 1.5658e-01, 2.0493e-01, -1.7839e-01, -2.9904e-01, -3.9076e-01, -2.9111e-01, -2.1097e-01, 6.3583e-01, 4.5969e-01, 7.0081e-02, -3.3865e-01, 5.2179e-01, -5.2456e-01, -3.2368e-01, 5.5066e-02]], [[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01, 3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01, -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01, 5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2618e-02, -1.0607e-01], [-5.0455e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01, -8.8198e-02, 3.6017e-01, -7.5252e-02, -1.5342e-01, -1.3452e-01, -1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02, -1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01], [-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01, -1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01, -1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02, 1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]]], grad_fn=<StackBackward>) -------------------------------------------------------------------------------- myoutput: [tensor([[ 1.2620e-01, 6.0072e-01, -7.1112e-02, 2.0916e-01, -1.7033e-01, -5.8128e-02, 3.4290e-01, 5.7120e-01, 7.6652e-04, 4.7431e-01, -9.7752e-02, -6.9819e-01, 4.4204e-02, 1.8705e-01, 3.7682e-01, -3.2877e-01, 3.3991e-01, -9.3203e-01, -4.5387e-01, -7.6271e-01], [ 6.4547e-01, -3.3936e-01, -6.3192e-01, -3.2661e-02, -5.8965e-01, -7.6409e-01, 1.3470e-01, -3.7835e-01, -2.0378e-01, -1.6322e-01, -6.0952e-01, 2.9986e-02, 7.8969e-02, -6.4902e-01, 7.5271e-01, 1.7919e-01, 6.5517e-01, -6.5625e-01, 3.2050e-01, -3.4623e-01], [-5.5407e-01, 2.0340e-01, 
4.1821e-01, -9.7931e-02, 4.2492e-01, 8.5182e-01, 7.9682e-02, 7.5144e-01, -2.0973e-01, -1.3963e-01, 4.5111e-01, -5.1502e-01, -3.1101e-01, 8.7050e-02, 7.7077e-01, 4.9754e-01, -1.6914e-01, 5.5128e-01, -7.0215e-01, 2.6817e-01]], grad_fn=<TanhBackward>), tensor([[-0.2918, 0.5013, 0.0336, -0.3569, -0.5727, -0.1577, -0.1704, 0.3353, -0.4692, -0.2399, 0.3714, 0.3964, -0.2294, 0.0909, 0.1388, -0.1164, 0.2566, -0.4419, 0.6232, -0.5399], [-0.7720, 0.3316, 0.4893, 0.4173, 0.1900, 0.5988, 0.2782, -0.3852, 0.1218, -0.1172, -0.4391, 0.1240, 0.3925, 0.3963, -0.5687, 0.2115, 0.4115, 0.5132, -0.1591, -0.1080], [ 0.1837, 0.2649, 0.6524, 0.2677, 0.0456, 0.2033, -0.0522, 0.4843, -0.4531, 0.4153, 0.0187, -0.6308, 0.1819, -0.5004, 0.6018, 0.4021, 0.4913, -0.5287, 0.1526, -0.1455]], grad_fn=<TanhBackward>), tensor([[ 0.1872, -0.1069, 0.4237, 0.4201, -0.6734, 0.0836, -0.0252, 0.2273, -0.2810, -0.0137, -0.2922, -0.3051, -0.2602, -0.4907, 0.0777, 0.1137, 0.2030, -0.1614, -0.0779, -0.2083], [-0.0990, 0.3498, 0.5492, -0.3256, 0.2025, 0.3302, -0.5011, -0.1571, 0.0209, 0.2982, 0.1901, -0.6905, 0.2419, -0.5201, 0.3651, 0.3990, 0.5685, -0.4665, 0.0143, -0.1595], [-0.5264, -0.0514, 0.1115, 0.3346, -0.2498, -0.0302, 0.4115, 0.3076, -0.5988, -0.0438, -0.3437, 0.1128, 0.2481, -0.0956, -0.2785, -0.1713, 0.2296, -0.1200, 0.0860, -0.2926]], grad_fn=<TanhBackward>), tensor([[-0.2957, 0.1804, 0.3002, 0.0617, -0.1344, 0.1993, -0.3224, 0.4173, -0.0781, 0.3736, -0.2150, 0.2653, -0.0528, 0.0651, -0.0500, 0.2519, -0.0915, -0.2620, -0.2110, -0.5948], [-0.1506, 0.4123, 0.0162, 0.1171, 0.0414, -0.0956, -0.2576, 0.4046, -0.6677, -0.0049, -0.2525, -0.2696, 0.2976, -0.4672, -0.0190, 0.1525, 0.2290, -0.4887, 0.0049, -0.7503], [-0.2533, -0.2999, 0.0536, -0.4347, -0.4320, 0.2809, -0.2127, -0.5016, -0.2124, 0.3309, -0.4574, -0.1008, -0.1006, -0.2328, 0.3993, 0.0364, 0.6901, 0.1125, 0.4137, 0.6626]], grad_fn=<TanhBackward>), tensor([[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01, 3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01, -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01, 5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2619e-02, -1.0607e-01], [-5.0548e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01, -8.8198e-02, 3.6017e-01, -7.5253e-02, -1.5342e-01, -1.3452e-01, -1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02, -1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01], [-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01, -1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01, -1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02, 1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]], grad_fn=<TanhBackward>)] official_output tensor([[[ 1.2620e-01, 6.0072e-01, -7.1112e-02, 2.0916e-01, -1.7033e-01, -5.8128e-02, 3.4290e-01, 5.7120e-01, 7.6649e-04, 4.7431e-01, -9.7752e-02, -6.9819e-01, 4.4204e-02, 1.8705e-01, 3.7682e-01, -3.2877e-01, 3.3991e-01, -9.3203e-01, -4.5387e-01, -7.6271e-01], [ 6.4547e-01, -3.3936e-01, -6.3193e-01, -3.2661e-02, -5.8965e-01, -7.6409e-01, 1.3470e-01, -3.7835e-01, -2.0378e-01, -1.6322e-01, -6.0952e-01, 2.9986e-02, 7.8969e-02, -6.4902e-01, 7.5271e-01, 1.7919e-01, 6.5517e-01, -6.5625e-01, 3.2050e-01, -3.4623e-01], [-5.5407e-01, 2.0340e-01, 4.1821e-01, -9.7931e-02, 4.2492e-01, 8.5182e-01, 7.9682e-02, 7.5144e-01, -2.0973e-01, -1.3963e-01, 4.5111e-01, -5.1502e-01, -3.1101e-01, 8.7050e-02, 7.7077e-01, 4.9754e-01, -1.6914e-01, 5.5128e-01, -7.0215e-01, 2.6817e-01]], [[-2.9177e-01, 5.0127e-01, 3.3566e-02, -3.5687e-01, -5.7271e-01, 
-1.5774e-01, -1.7043e-01, 3.3525e-01, -4.6915e-01, -2.3995e-01, 3.7142e-01, 3.9644e-01, -2.2941e-01, 9.0899e-02, 1.3878e-01, -1.1636e-01, 2.5660e-01, -4.4189e-01, 6.2322e-01, -5.3986e-01], [-7.7200e-01, 3.3155e-01, 4.8930e-01, 4.1734e-01, 1.8999e-01, 5.9885e-01, 2.7816e-01, -3.8521e-01, 1.2183e-01, -1.1717e-01, -4.3911e-01, 1.2396e-01, 3.9253e-01, 3.9633e-01, -5.6871e-01, 2.1150e-01, 4.1146e-01, 5.1318e-01, -1.5914e-01, -1.0799e-01], [ 1.8367e-01, 2.6493e-01, 6.5243e-01, 2.6774e-01, 4.5578e-02, 2.0329e-01, -5.2159e-02, 4.8428e-01, -4.5313e-01, 4.1533e-01, 1.8746e-02, -6.3081e-01, 1.8190e-01, -5.0044e-01, 6.0178e-01, 4.0211e-01, 4.9127e-01, -5.2867e-01, 1.5256e-01, -1.4553e-01]], [[ 1.8723e-01, -1.0690e-01, 4.2369e-01, 4.2007e-01, -6.7342e-01, 8.3559e-02, -2.5240e-02, 2.2735e-01, -2.8096e-01, -1.3662e-02, -2.9221e-01, -3.0512e-01, -2.6019e-01, -4.9072e-01, 7.7736e-02, 1.1373e-01, 2.0299e-01, -1.6141e-01, -7.7901e-02, -2.0833e-01], [-9.8969e-02, 3.4982e-01, 5.4921e-01, -3.2558e-01, 2.0254e-01, 3.3020e-01, -5.0109e-01, -1.5706e-01, 2.0853e-02, 2.9821e-01, 1.9009e-01, -6.9054e-01, 2.4189e-01, -5.2012e-01, 3.6514e-01, 3.9902e-01, 5.6852e-01, -4.6647e-01, 1.4296e-02, -1.5953e-01], [-5.2637e-01, -5.1397e-02, 1.1150e-01, 3.3456e-01, -2.4977e-01, -3.0166e-02, 4.1154e-01, 3.0765e-01, -5.9878e-01, -4.3782e-02, -3.4375e-01, 1.1282e-01, 2.4812e-01, -9.5623e-02, -2.7851e-01, -1.7131e-01, 2.2957e-01, -1.1999e-01, 8.5984e-02, -2.9264e-01]], [[-2.9568e-01, 1.8038e-01, 3.0018e-01, 6.1720e-02, -1.3442e-01, 1.9932e-01, -3.2239e-01, 4.1725e-01, -7.8142e-02, 3.7360e-01, -2.1505e-01, 2.6528e-01, -5.2758e-02, 6.5120e-02, -4.9986e-02, 2.5186e-01, -9.1457e-02, -2.6198e-01, -2.1105e-01, -5.9480e-01], [-1.5058e-01, 4.1227e-01, 1.6235e-02, 1.1707e-01, 4.1378e-02, -9.5621e-02, -2.5761e-01, 4.0463e-01, -6.6765e-01, -4.8583e-03, -2.5254e-01, -2.6960e-01, 2.9760e-01, -4.6718e-01, -1.9016e-02, 1.5246e-01, 2.2903e-01, -4.8867e-01, 4.9081e-03, -7.5035e-01], [-2.5332e-01, -2.9992e-01, 5.3646e-02, -4.3469e-01, -4.3205e-01, 2.8094e-01, -2.1272e-01, -5.0162e-01, -2.1240e-01, 3.3086e-01, -4.5738e-01, -1.0083e-01, -1.0064e-01, -2.3278e-01, 3.9928e-01, 3.6350e-02, 6.9007e-01, 1.1249e-01, 4.1367e-01, 6.6261e-01]], [[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01, 3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01, -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01, 5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2618e-02, -1.0607e-01], [-5.0455e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01, -8.8198e-02, 3.6017e-01, -7.5252e-02, -1.5342e-01, -1.3452e-01, -1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02, -1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01], [-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01, -1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01, -1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02, 1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]]], grad_fn=<StackBackward>)
As you can see, the two models agree; the hand-written model's structure is correct.
Conclusion: both the theory and the code above show that an RNN is a shared-parameter temporal neural network (shared parameters, also known as weight sharing). A one-layer RNN has four parameter variables, and the sample at every timestep is combined with these same four parameters. Each higher layer depends on the result of the layer below it, and each timestep depends on the result of the previous timestep.
If anything in this article is inaccurate, corrections and criticism are welcome.
You can also add me on WeChat (13681536300, please note: ML) so we can exchange ideas and improve together.

