QA系統Match-LSTM代碼研讀


QA系統Match-LSTM代碼研讀

背景

在QA模型中,Match-LSTM是較早提出的,使用Prt-Net邊界模型。本文是對閱讀其實現代碼的總結。主要思路是對照着論文和代碼,對論文中模型的關鍵結構,查看代碼中的具體實現。參考代碼是MurtyShikhar實現的

模型簡介

模型的輸入是(Passage, Question),模型的輸出是(start_idx, end_idx)。對於輸入,Passage是QA任務中的正文,輸入給模型時已經轉化為經過Padding的id-list;Question是QA任務中的問題,輸入給模型時已經轉化為經過Padding的id-list。對於輸出,start_idx是答案在正文的起始位置,end_idx是答案在正文的結束位置。

用於QA的Match-LSTM模型主要由三層構成:

  1. LSTM預處理層。
    分別將Passage和Question通過LSTM進行處理,使每個位置的表示都帶有一些上下文信息。
  2. Match-LSTM層。
    Match-LSTM最早用於文本蘊含,輸入一個前提,一個猜測,判斷前提是否能蘊含猜測。在用於QA任務時,Question被當做前提,Passage被當做猜測。依次處理Passage的每個位置,計算Passage每個位置對Question的Attention,進而求出對Question的Attend Vector。該Attend Vector與第一層的輸出拼接起來,輸入給一個LSTM進行處理,這整個流程被稱作Match-LSTM。
    其中Attention選擇BahdanauAttention,Attention的輸入(Query)由上一時刻Match-LSTM的輸出及Passage在當前位置的表示拼接,Attention的key是Question每個位置的表示,Attention的value也是Question每個位置的表示。根據Attention的alignment對Attention Value加權求和計算出Attend Vector。
    所以,Match-LSTM本質上由一個LSTM單元和一個Attention單元組成。LSTM單元的輸出作為Match-LSTM層的輸出,LSTM單元的狀態和下一個位置的輸入拼接起來作為Attention單元的輸入(Query),Attention單元的輸出(Attend Vector)與當前位置的輸入拼接起來作為LSTM單元的輸入。也可以理解為在LSTM的基礎上增加Attention,改變LSTM的輸入,在LSTM的原始輸入上增加當前位置對於Question的Attention。
  3. Pointer-Net層。
    Pointer-Net層在代碼實現上,與Match-LSTM十分接近。只在涉及輸入、輸出的地方有幾處不同。從原理上看,Pointer-Net層也是一個序列化迭代的Attention過程,首先用zero_state作為query對Match-LSTM層的所有輸出計算attention,作為回答第一個符號的logit。然后以AttentionWrapper的輸出作為下一時刻的query,對Match-LSTM層的所有輸出計算attention,如此迭代進行。對於邊界模型,秩序計算start_index和end_index,這個迭代過程秩序進行兩次。

接下來的幾部分對照論文及代碼中模型關鍵結構實現。

模型

模型圖構建的入口在qa_model.py文件中class QASystem類的def setup_system(self)方法內。這一節主要就是對該方法的細節展開解讀。

LSTM預處理層

所有邏輯都包含在qa_model.py文件中,入口位於class QASystem類的def setup_system(self)方法內,具體邏輯位於class Encoderdef encode(self, inputs, masks, encoder_state_input = None)方法內。

def setup_system(self)方法內,通過以下語句調用class Encoderdef encode(self, inputs, masks, encoder_state_input = None)方法。

encoder = self.encoder
decoder = self.decoder
encoded_question, encoded_passage, q_rep, p_rep = encoder.encode([self.question, self.passage],
				 [self.question_lengths, self.passage_lengths], encoder_state_input = None)

再看一下encode方法的實現。

def encode(self, inputs, masks, encoder_state_input = None):
    """
    :param inputs: vector representations of question and passage (a tuple) 
    :param masks: masking sequences for both question and passage (a tuple)
    :param encoder_state_input: (Optional) pass this as initial hidden state to tf.nn.dynamic_rnn to build conditional representations
    :return: an encoded representation of the question and passage.
    """
    
    question, passage = inputs
    masks_question, masks_passage = masks

    # read passage conditioned upon the question
    with tf.variable_scope("encoded_question"):
        lstm_cell_question = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        encoded_question, (q_rep, _) = tf.nn.dynamic_rnn(lstm_cell_question, question, masks_question, dtype=tf.float32) # (-1, 
Q, H)

    with tf.variable_scope("encoded_passage"):
        lstm_cell_passage  = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        encoded_passage, (p_rep, _) =  tf.nn.dynamic_rnn(lstm_cell_passage, passage, masks_passage, dtype=tf.float32) # (-1, P, 
H)
    # outputs beyond sequence lengths are masked with 0s
    return encoded_question, encoded_passage , q_rep, p_rep

從代碼可以看出,對Passage和Question的預處理就是分別經過兩個單向LSTM層(不共享參數),LSTM每個位置的輸出作為預處理后的表示。

Match-LSTM層

Match-LSTM的邏輯主要在qa_model.py和attention_wrapper.py兩個文件中。雖然tensorflow的contrib庫中現在也有attention_wrapper這個模塊,但是兩者在具體實現上不太相同。入口位於qa_model.py文件class Decoder類中decode方法內。

首先,看一下最外層的入口,與LSTM預處理層一樣,位於class QASystem類的def setup_system(self)方法內。

if self.config.use_match:
    self.logger.info("\n========Using Match LSTM=========\n")
    logits= decoder.decode([encoded_question, encoded_passage], q_rep, [self.question_lengths, self.passage_lengths], self.
labels)

接下來,進入class Decoder類中decode方法。函數邏輯非常清晰,先通過Match-LSTM層,再通過Ptr-Net層。

def decode(self, encoded_rep, q_rep, masks, labels):
    output_attender = self.run_match_lstm(encoded_rep, masks)
    logits = self.run_answer_ptr(output_attender, masks, labels)

    return logits

然后進入run_match_lstm方法。

def run_match_lstm(self, encoded_rep, masks):
    encoded_question, encoded_passage = encoded_rep
    masks_question, masks_passage = masks

    match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
    query_depth = encoded_question.get_shape()[-1]


    # output attention is false because we want to output the cell output and not the attention values
    with tf.variable_scope("match_lstm_attender"):
        attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
        cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        lstm_attender  = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)

        # we don't mask the passage because masking the memories will be handled by the pointerNet
        reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)

        output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
        output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")

        output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)


    output_attender = tf.concat([output_attender_fw, output_attender_bw], axis = -1) # (-1, P, 2*H)
    return output_attender

該方法的輸入encoded_rep是一個tuple,包含Passage和Question的表示;masks也是一個tuple,包含Passage和Question的長度。

match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)

這條語句定義了Match-LSTM單元中AttentionMechanism的輸入函數,作為參數該函數被傳遞給AttentionWrapper的構造函數,作為attention_input_fn AttentionWrapper本身也是一個RNN,它組合了一個RNN和一個AttentionMechanism,形成一個高級的RNN單元。該函數就是定義了用於Attention機制的Query是如何生成的,由當前時刻的輸入拼接上一個時刻的state,形成Attention的Query。

attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)

這條語句定義了一個AttentionMechanism,也就是一個Attention單元,該類包含一個__call__方法,調用該對象可以計算出alignments,調用該類對象的參數如方法定義所示def __call__(self, query, previous_alignments)。聯系上面一起來看,這里的query就是上面所說的Attention的Query。
至於BahdanauAttention是如何實現的,暫時不做過詳細的介紹,目前該類位於tf.contrib.seq2seq.BahdanauAttention,已經是tensorflow庫的一部分。

cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)

這條語句定義一個普通的LSTM單元。

lstm_attender  = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)

這條語句將上面兩步定義的AttentionMechanism及LSTM單元組裝為一個高級RNN單元。參數還包括了在run_match_lstm方法一開頭頂一個的一個函數,該函數用來生成AttentionMechanismquery

reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)

分別正向、反向對Passage的表示應用Match-LSTM,再將輸出沿最后一個維度拼接起來作為Match-LSTM層的輸出。

我們還可以再近距離看一下LSTM單元和AttentionMechanism是如何配合工作的,這需要深入到AttentionWrappercall方法,這也是所有RNN單元都需要實現的一個方法。

def call(self, inputs, state):
	output_prev_step = state.cell_state.h # get hr_(i-1)
    attention_input = self._attention_input_fn(inputs, output_prev_step) # get input to BahdanauAttention to get alpha_i
    alignments, raw_scores = self._attention_mechanism(
        attention_input, previous_alignments=state.alignments)

    expanded_alignments = array_ops.expand_dims(alignments, 1)

    attention_mechanism_values = self._attention_mechanism.values
    context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
    context = array_ops.squeeze(context, [1])


    cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    if self._attention_layer is not None:
      attention = self._attention_layer(
          array_ops.concat([cell_output, context], 1))
    else:
      attention = context

    if self._alignment_history:
      alignment_history = state.alignment_history.write(
          state.time, alignments)
    else:
      alignment_history = ()

    next_state = AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=alignments,
        alignment_history=alignment_history)

    if self._output_attention:
      return raw_scores, next_state
    else:
      return cell_output, next_state
output_prev_step = state.cell_state.h # get hr_(i-1)
attention_input = self._attention_input_fn(inputs, output_prev_step)

取LSTM單元上一時刻的狀態,與AttentionWrapper當前時刻的輸入,通過self._attention_input_fn函數生成attention的Query。這里的self._attention_input_fn就是上面AttentionWrapper構造函數的參數attention_input_fn

alignments, raw_scores = self._attention_mechanism(attention_input, previous_alignments=state.alignments)

調用AttentionMechaism對象,計算Attention的alignments。這里的self._attention_mechanism就是AttentionWrapper構造函數的參數attention_mechanism_match_lstm,也就是BahdanauAttention的一個對象。

expanded_alignments = array_ops.expand_dims(alignments, 1)       # [batch_size, 1, ques_size]
attention_mechanism_values = self._attention_mechanism.values   # [batch_size, ques_size, value_dims]
context = math_ops.matmul(expanded_alignments, attention_mechanism_values) # [batch_size, 1, value_dims]
context = array_ops.squeeze(context, [1])   # [batch_size, value_dims]

通過alignments和attention的Values,計算attend vector,就是對values以alignments為權重求和。

cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

通過_cell_input_fn將當前時刻的輸入,和attend vector組合起來,成為當前時刻LSTM的輸入。然后調用LSTM單元計算當前時刻LSTM單元的輸出和狀態。

if self._attention_layer is not None:
   attention = self._attention_layer(
                     array_ops.concat([cell_output, context], 1))
else:
   attention = context

是否需要對attend vector再進行一次線性變換,作為attention,在本例中未做變換,直接用attend vector作為attention。

next_state = AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      alignments=alignments,
      alignment_history=alignment_history)

作為RNN的AttentionWrapper的下一時刻狀態。

if self._output_attention:
    return raw_scores, next_state
  else:
    return cell_output, next_state

根據構造函數的參數,決定AttentionWrapper的輸出是attention score還是LSTM的輸出,attention score的意義是求alignments概率之前的那個東西。

Pointer-Net層

以下代碼是Pointer-Net層的邏輯,與Match-LSTM層的邏輯非常接近,但是在一些細節上有所區別。相似的部分是,Pointer-Net層的主體也是通過一個AttentionWrapper完成的,也是組裝了一個LSTM單元和一個BahdanauAttention單元。與Match-LSTM不同的地方是,LSTM單元及BahdanauAttention單元的輸入函數不同,AttentionWrapper的輸出內容不同,並且Pointer-Net層使用一個靜態rnn。

def run_answer_ptr(self, output_attender, masks, labels):
    batch_size = tf.shape(output_attender)[0]
    masks_question, masks_passage = masks
    labels = tf.unstack(labels, axis=1) 
    #labels = tf.ones([batch_size, 2, 1])


    answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
    query_depth_answer_ptr = output_attender.get_shape()[-1]

    with tf.variable_scope("answer_ptr_attender"):
        attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
        # output attention is true because we want to output the attention values
        cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
        answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
        logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)

        return logits 

接下來具體看一下上面這段代碼。

batch_size = tf.shape(output_attender)[0]       # [batch_size, passage_length, 2 * hidden_size]
masks_question, masks_passage = masks
labels = tf.unstack(labels, axis=1)     # labels : [batch_size, 2]

output_attender是上一層,也就是Match-LSTM層的輸出,形狀為[batch_size, passage_length, 2 * hidden_size]labels的形狀為[batch_size, 2]masks_questionmasks_passage分別為問題的長度和文章的長度。

answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
query_depth_answer_ptr = output_attender.get_shape()[-1]

answer_ptr_cell_input_fn 定義了AttentionWrapperLSTM單元的輸入函數。query_depth_answer_ptr從變量名的字面含義看,是Answer-Ptr層的attention單元的query的維度。

with tf.variable_scope("answer_ptr_attender"):
   attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
   # output attention is true because we want to output the attention values
   cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
   answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)

接下來是裝配AttentionWrapper這里與Match-LSTM層有區別。在Match-LSTM層的定義中,沒有顯式地為AttentionWrapper指定cell_input_fn參數,而是使用了默認函數。在Match-LSTM層的定義中,顯式指定了attention_input_fn,但是這里沒有指定,使用了默認函數。另外一個區別,在Match-LSTM層的定義中,AttentionWrapperoutput_attention參數是False,在這里該參數用默認的True

對比Match-LSTM層與Pointer-Net層cell_input_fn的區別。

默認的cell_input_fn的定義如下,這是Match-LSTM層采用的。邏輯是將attention的輸出和當前的輸入拼接起來,作為LSTM單元的輸入。

if cell_input_fn is None:
   cell_input_fn = ( 
       lambda inputs, attention: array_ops.concat([inputs, attention], -1))

Pointer-Net層使用的cell_input_fn在上面的代碼中已經給出,這里對比一下。只用Attention單元的輸出,作為LSTM單元的輸入。這樣,LSTM單元的輸入,就與RNN的輸入無關了。

answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question

對比Match-LSTM層與Pointer-Net層attention_input_fn的區別。

Match-LSTM層采用的attention_input_fn是非默認的,在上一節中已經給出,這里對比一下。

match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)

Pointer-Net層的attention_input_fn是默認的,定義如下。

if attention_input_fn is None:
   attention_input_fn = ( 
       lambda _, state: state)

可以看出,在Match-LSTM層,attention單元的輸入是上一時刻狀態與當前輸入的拼接。在Pointer-Net層,attention單元的輸入僅僅是上一時刻的狀態,與當前時刻的輸入無關。

綜上兩處,可以看出區別。在Match-LSTM層,無論Attention單元還是LSTM單元,其輸入都要拼接當前時刻輸入。而在Pointer-Net層,無論Attention單元還是LSTM單元,其輸入都與當前時刻的輸入無關。這也解釋了我最早看代碼時的疑惑,為什么計算logits的函數需要labels作為參數,labels不是只有在計算loss的時候才需要嗎?其實雖然這里有labels這個參數,但是沒有實際使用其內容,對於預測過程,只需傳一個同樣形狀的tensor就可以。

再對比最后一個區別,Match-LSTM層與Pointer-Net層在output_attention參數上的區別。

if self._output_attention:
    return raw_scores, next_state
else:
    return cell_output, next_state

raw_scoresattention單元的原始輸出,即通過softmax計算alignments前的那個輸出。cell_outputLSTM單元的輸出,也就是狀態h。在Match-LSTM層,AttentionWrapper輸出的是其內部LSTM單元的輸出。在Pointer-Net層,AttentionWrapper輸出的是其內部attention單元的raw_scores

logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)

最后是計算logits。因為labels是個長度為2的listlogits也是長度為2的list。但是,這兩個list中元素的shape是不一樣的,labels中的元素的shape[batch_size, 1],logits中的元素的shape[batch_size, passage_length]
從代碼層面來理解,首先是以zero_state為query去計算attention,attention單元的key和value都是Match-LSTM層的輸出,attention計算的raw_score就是第一個輸出的logitattention計算出的alignmentsvalues計算attend vector,以其為輸入計算LSTM單元的輸出,作為下一時刻的query去計算attention。這樣,就計算出了兩個logits

至此,計算出logits,預測部分就已經完成了。logits是一個長度為2的list,其中每個元素是一個shape[batch_size, passage_length]的tensor。

損失函數

有了logits,就可以計算損失函數了。

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[0], labels=self.labels[:,0])
losses += tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[1], labels=self.labels[:,1])
self.loss = tf.reduce_mean(losses)

這里只需要理解一個函數即可tf.nn.sparse_softmax_cross_entropy_with_logits,該函數logits參數的ranklabels多1,多出的那個axis的維度是num_classeslabels以稀疏形式表示,每個元素都是整數,小於num_classes

由於之前已經知道,Pointer-Net層求出logits是一個list,每個元素的形狀是[batch_size, passage_length],而輸入的labels的形狀是[batch_size, 2]。因此按照上面代碼的方式調用可求出損失函數。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM