Seq2Seq在訓練階段和預測階段稍有差異。如果Decoder第一個預測預測的輸出就錯了,它會導致“蝴蝶效應“,影響后面全部內容。為了解決這個問題,在訓練時,Decoder每個時間步的輸入不全是上一個時間步的輸出,而以一定的概率選擇真實值作為輸入。
通常,Encoder的輸入序列需要添加一個終止符“<eos>”,可以不需要起始符“<sos>”。Decoder輸入序列在訓練時則需要添加一個起始符和終止符,在預測時,Decoder接收一個起始符“<sos>”,它類似一個信號,告訴Decoder可以開始工作了,當輸出終止符時我們就可以停下來(通常可以再設置一個最大輸出長度,防止Decoder一直不輸出終止符)。
終止符和起始符只要不會出現在原始序列中就可以了,也可以用<start>和<stop>,<bos>和<eos>,<s>和</s>等等
Attention機制
這里介紹的是LuongAttention
整個輸入序列的信息被Encoder編碼為固定長度的向量,類似”有損壓縮”。這個向量無法完全表達整個輸入序列的信息。另外,隨着輸入長度的增加,這個固定長度的向量,會逐漸丟失更多信息。
以英中翻譯任務為例,我們翻譯的時候,雖然要考慮上下文,但每個時間步的輸出,不同單詞的貢獻是不同的。考慮下面這個句子對:
She doesn't like soccer.
她不喜歡足球。
我們翻譯“她”時,其實只需要考慮“She”就好了,“足球”也是同理。簡單說,Attention機制讓我們的輸出時,關注輸入序列中的某一些部位就可以了,即讓輸入的單詞有不同的貢獻。
根據原始論文,我們定義以下符號:在每個時間步$t$,Decoder當前時間步的隱藏狀態$h_t$,整個Encoder輸出的隱藏狀態$\bar h_s$,權重數值$a_t$,上下文向量$c_t$。
注意力值通過以下方式計算:
$$
score(h_t,\bar h_s)=
\begin{cases}
h_t^T\bar h_s & \text{dot} \\
h_t^TW_a\bar h_s & \text{general} \\
v_a^T\tanh (W_a[h_t;\bar h_s]) & \text{concat}
\end{cases}
$$
其中權重根據以下公式計算(其實就是用softmax歸一化了)
$$
a_t(s)=align(h_t, \bar h_s)=\frac {exp(score(h_t, \bar h_s))}{\sum_{s'} exp(score(h_t, \bar h_{s'}))}
$$
上下文向量根據權重,對Encoder輸出隱藏狀態的每個時間步進行加權平均
$$
c_t=\sum_s a_t(s) \cdot \bar h_s
$$
與Decoder當前時間步的隱藏狀態拼接,計算一個注意力隱藏狀態,其計算公式如下
$$
\tilde h_t = \tanh (W_c[c_t;h_t])
$$
再根據這個注意力隱藏狀態預測輸出結果
$$
y = \text{softmax}(W_s\tilde h_t)
$$
部分代碼
參考了官方文檔和github上的一些代碼,使用Attention機制和不使用Attention機制的翻譯器都實現了一下。這里只對使用了Attention機制的翻譯器的部分代碼進行說明,完整代碼如下
https://gitee.com/dogecheng/python/blob/master/pytorch/Seq2SeqForTranslation.ipynb
在計算出注意力值后,Decoder將其與Encoder輸出的隱藏狀態進行加權平均,得到上下文向量context.
再將context與Decoder當前時間步的隱藏狀態拼接,經過tanh。最后用softmax預測最終的輸出概率。
class Decoder(nn.Module): def forward(self, token_inputs, last_hidden, encoder_outputs): ... # encoder_outputs = [input_lengths, batch, hid_dim * n directions] attn_weights = self.attn(gru_output, encoder_outputs) # attn_weights = [batch, 1, sql_len] context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # [batch, 1, hid_dim * n directions] gru_output = gru_output.squeeze(0) # [batch, n_directions * hid_dim] context = context.squeeze(1) # [batch, n_directions * hid_dim] concat_input = torch.cat((gru_output, context), 1) # [batch, n_directions * hid_dim * 2] concat_output = torch.tanh(self.concat(concat_input)) # [batch, n_directions*hid_dim] output = self.out(concat_output) # [batch, output_dim] output = self.softmax(output) ...
訓練時,根據use_teacher_forcing設置的閾值,決定下一時間步的輸入是上一時間步的預測結果還是來自數據的真實值
if self.predict: """ 預測代碼 """ ... else: max_target_length = max(target_lengths) all_decoder_outputs = torch.zeros((max_target_length, batch_size, self.decoder.output_dim), device=self.device) for t in range(max_target_length): use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False if use_teacher_forcing: # decoder_output = [batch, output_dim] # decoder_hidden = [n_layers*n_directions, batch, hid_dim] decoder_output, decoder_hidden, decoder_attn = self.decoder( decoder_input, decoder_hidden, encoder_outputs ) all_decoder_outputs[t] = decoder_output decoder_input = target_batches[t] # 下一個輸入來自訓練數據 else: decoder_output, decoder_hidden, decoder_attn = self.decoder( decoder_input, decoder_hidden, encoder_outputs ) # [batch, 1] topv, topi = decoder_output.topk(1) all_decoder_outputs[t] = decoder_output decoder_input = topi.squeeze(1).detach() # 下一個輸入來自模型預測
損失函數通過使用設置ignore_index不計padding部分的損失
loss_fn = nn.NLLLoss(ignore_index=PAD_token) loss = loss_fn( all_decoder_outputs.reshape(-1, self.decoder.output_dim), # [batch*seq_len, output_dim] target_batches.reshape(-1) # [batch*seq_len] )
Seq2Seq在預測階段每次只輸入一個樣本,輸出其翻譯結果,對應forward()函數中的內容如下,當Decoder輸出終止符或輸出長度達到所設定的閾值時便停止。
class Seq2Seq(nn.Module): ... def forward(self, input_batches, input_lengths, target_batches=None, target_lengths=None, teacher_forcing_ratio=0.5): ... if self.predict: # 一次只輸入一句話 assert batch_size == 1, "batch_size of predict phase must be 1!" output_tokens = [] while True: decoder_output, decoder_hidden, decoder_attn = self.decoder( decoder_input, decoder_hidden, encoder_outputs ) # [1, 1] topv, topi = decoder_output.topk(1) decoder_input = topi.squeeze(1).detach() output_token = topi.squeeze().detach().item() if output_token == EOS_token or len(output_tokens) == self.max_len: break output_tokens.append(output_token) return output_tokens else: """ 訓練代碼 """ ...
部分實驗結果,具體可以在notebook里看
參考資料
NLP FROM SCRATCH: TRANSLATION WITH A SEQUENCE TO SEQUENCE NETWORK AND ATTENTION
DEPLOYING A SEQ2SEQ MODEL WITH TORCHSCRIPT
Practical PyTorch: Translation with a Sequence to Sequence Network and Attention