Updating h0, c0 in PyTorch LSTM/GRU


The LSTM initial hidden states h0 and c0 are usually initialized to zero, and in most cases the model works well that way. Sometimes, however, it seems more reasonable to draw h0 and c0 from a random distribution, or to treat them directly as model parameters and optimize them.
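For reference, PyTorch's nn.LSTM already defaults to zero states when none are passed, so the two calls below produce the same output (the shapes here are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
out_default, _ = lstm(x)             # h0, c0 omitted -> default to zeros
h0 = torch.zeros(2, 3, 20)           # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 3, 20)
out_explicit, _ = lstm(x, (h0, c0))
assert torch.allclose(out_default, out_explicit)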

The following post provides empirical evidence for this:

Non-Zero Initial States for Recurrent Neural Networks

Its reported empirical results support three conclusions: (1) non-zero initial-state initialization speeds up training and improves generalization; (2) training the initial state as a model parameter is more effective than initializing it with zero-mean noise; (3) if the initial state is learned, adding noise on top of it brings no additional benefit.

In short, if your data consists of many short sequences, training the initial state can speed up learning. Conversely, if the data contains only a few long sequences, there may not be enough data to train the initial state effectively; in that case a noisy initial state can speed up learning. One point the post does not address is how to choose the mean and standard deviation of the noise generator. In addition, Trick 4 of Forecasting with Recurrent Neural Networks: 12 Tricks proposes an adaptive scheme in which the magnitude of the initial-state noise is adjusted according to the back-propagated error.

Whether this works well in practice remains to be verified.
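As a minimal sketch of the noisy-initial-state idea, one could draw h0 and c0 from zero-mean Gaussian noise at training time. The noise_std below is a hyperparameter chosen purely for illustration; neither reference prescribes a value, and this does not implement Trick 4's adaptive scheme:

import torch
import torch.nn as nn

class NoisyStateLSTM(nn.LSTM):
    """LSTM whose default initial states are zero-mean Gaussian noise at train time."""

    def __init__(self, *args, noise_std=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.noise_std = noise_std  # illustrative hyperparameter, to be tuned

    def forward(self, rnn_input, prev_states=None):
        if prev_states is None and self.training:
            num_directions = 2 if self.bidirectional else 1
            size = (self.num_layers * num_directions, rnn_input.size(1), self.hidden_size)
            h0 = torch.randn(size, device=rnn_input.device) * self.noise_std
            c0 = torch.randn(size, device=rnn_input.device) * self.noise_std
            prev_states = (h0, c0)
        # At eval time, or when states are passed explicitly, behaviour is unchanged.
        return super().forward(rnn_input, prev_states)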

 

In fact, the LSTM initial hidden states h0 and c0 can be viewed as part of the model's parameters and updated during training. Below is one way to implement learnable h0 and c0 for an LSTM in PyTorch (from a Zhihu answer).

Author: 鄭華濱
Link: https://www.zhihu.com/question/270772480/answer/358198157
Source: Zhihu
Copyright belongs to the author. For commercial reproduction, please contact the author for authorization; for non-commercial reproduction, please credit the source.

import torch
import torch.nn as nn

class EasyLSTM(nn.LSTM):
    """nn.LSTM whose initial states (h0, c0) are learnable parameters."""

    def __init__(self, *args, **kwargs):
        nn.LSTM.__init__(self, *args, **kwargs)
        self.num_direction = 2 if self.bidirectional else 1
        # One learnable state vector per layer/direction, shared by the whole batch.
        state_size = (self.num_layers * self.num_direction, 1, self.hidden_size)
        self.init_h = nn.Parameter(torch.zeros(state_size))
        self.init_c = nn.Parameter(torch.zeros(state_size))

    def forward(self, rnn_input, prev_states=None):
        # Assumes batch_first=False, i.e. input shape (seq_len, batch, input_size).
        batch_size = rnn_input.size(1)
        if prev_states is None:
            state_size = (self.num_layers * self.num_direction, batch_size, self.hidden_size)
            # expand() broadcasts the learned states over the batch dimension;
            # contiguous() materializes the copy that the cuDNN backend requires.
            init_h = self.init_h.expand(*state_size).contiguous()
            init_c = self.init_c.expand(*state_size).contiguous()
            prev_states = (init_h, init_c)
        rnn_output, states = nn.LSTM.forward(self, rnn_input, prev_states)
        return rnn_output, states
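EasyLSTM is then a drop-in replacement for nn.LSTM: called without states it uses the learned h0/c0, and it still accepts explicit states. The shapes below are illustrative:

lstm = EasyLSTM(10, 20, 2)           # input_size=10, hidden_size=20, num_layers=2
x = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
out, (hn, cn) = lstm(x)              # learned init_h/init_c, expanded to batch size 3
out, (hn, cn) = lstm(x, (hn, cn))    # explicit states still work, e.g. across segments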

 

The same pattern extends to all four recurrent modules: LSTM, GRU, LSTMCell, and GRUCell.

import torch
import torch.nn as nn

# EasyLSTM is defined in the previous listing; the same pattern is applied below.

class EasyGRU(nn.GRU):
    """nn.GRU whose initial hidden state h0 is a learnable parameter."""

    def __init__(self, *args, **kwargs):
        nn.GRU.__init__(self, *args, **kwargs)
        self.num_direction = 2 if self.bidirectional else 1
        state_size = (self.num_layers * self.num_direction, 1, self.hidden_size)
        self.init_h = nn.Parameter(torch.zeros(state_size))

    def forward(self, rnn_input, prev_states=None):
        batch_size = rnn_input.size(1)
        if prev_states is None:
            state_size = (self.num_layers * self.num_direction, batch_size, self.hidden_size)
            # Broadcast the learned state over the batch dimension.
            prev_states = self.init_h.expand(*state_size).contiguous()
        rnn_output, states = nn.GRU.forward(self, rnn_input, prev_states)
        return rnn_output, states


class EasyLSTMCell(nn.LSTMCell):
    """nn.LSTMCell whose initial states (h, c) are learnable parameters."""

    def __init__(self, *args, **kwargs):
        nn.LSTMCell.__init__(self, *args, **kwargs)
        state_size = (1, self.hidden_size)
        self.init_h = nn.Parameter(torch.zeros(state_size))
        self.init_c = nn.Parameter(torch.zeros(state_size))

    def forward(self, rnn_input, prev_states=None):
        # Cells take a single time step: input shape (batch, input_size).
        batch_size = rnn_input.size(0)
        if prev_states is None:
            state_size = (batch_size, self.hidden_size)
            init_h = self.init_h.expand(*state_size).contiguous()
            init_c = self.init_c.expand(*state_size).contiguous()
            prev_states = (init_h, init_c)
        h, c = nn.LSTMCell.forward(self, rnn_input, prev_states)
        return h, c


class EasyGRUCell(nn.GRUCell):
    """nn.GRUCell whose initial hidden state h is a learnable parameter."""

    def __init__(self, *args, **kwargs):
        nn.GRUCell.__init__(self, *args, **kwargs)
        state_size = (1, self.hidden_size)
        self.init_h = nn.Parameter(torch.zeros(state_size))

    def forward(self, rnn_input, prev_states=None):
        batch_size = rnn_input.size(0)
        if prev_states is None:
            state_size = (batch_size, self.hidden_size)
            prev_states = self.init_h.expand(*state_size).contiguous()
        h = nn.GRUCell.forward(self, rnn_input, prev_states)
        return h


if __name__ == '__main__':

    # The wrappers accept explicit states exactly like the modules they extend.
    lstm = EasyLSTM(10, 20, 2)
    x = torch.randn(5, 3, 10)        # (seq_len, batch, input_size)
    h0 = torch.randn(2, 3, 20)
    c0 = torch.randn(2, 3, 20)
    output, (hn, cn) = lstm(x, (h0, c0))

    gru = EasyGRU(10, 20, 2)
    x = torch.randn(5, 3, 10)
    h0 = torch.randn(2, 3, 20)
    output, hn = gru(x, h0)

    lstmcell = EasyLSTMCell(10, 20)
    x = torch.randn(6, 3, 10)        # 6 time steps, batch of 3
    h = torch.randn(3, 20)
    c = torch.randn(3, 20)
    out = []
    for i in range(6):
        h, c = lstmcell(x[i], (h, c))
        out.append(h)

    grucell = EasyGRUCell(10, 20)
    x = torch.randn(6, 3, 10)
    h = torch.randn(3, 20)
    out = []
    for i in range(6):
        h = grucell(x[i], h)
        out.append(h)
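A quick sanity check, using a toy loss purely for illustration, confirms that gradients flow back into the learned initial states, so an optimizer step will update them along with the other weights:

lstm = EasyLSTM(10, 20, 2)
x = torch.randn(5, 3, 10)
out, _ = lstm(x)                 # no states passed, so init_h/init_c are used
loss = out.pow(2).mean()         # toy objective, for illustration only
loss.backward()
print(lstm.init_h.grad.shape)    # torch.Size([2, 1, 20]): batch gradients summed back
assert lstm.init_c.grad is not None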

 

 

References:

Non-Zero Initial States for Recurrent Neural Networks

PyTorch LSTM: updating h0, c0 (Chinese blog post)

Best way to initialize LSTM state

Tips for Training Recurrent Neural Networks: https://danijar.com/tips-for-training-recurrent-neural-networks/

 
        

 

