Reading the Code of a Match-LSTM QA System
Background
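Match-LSTM is one of the earlier QA models, and it uses the Ptr-Net boundary model. This post summarizes what I learned from reading one implementation of it. The approach is to read the paper and the code side by side and, for each key structure of the model in the paper, look at how the code implements it. The reference code is the implementation by MurtyShikhar.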
Model Overview
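The model's input is (Passage, Question) and its output is (start_idx, end_idx). On the input side, Passage is the body text of the QA task and Question is the question; both have already been converted into padded id lists when they are fed to the model. On the output side, start_idx is the start position of the answer in the passage and end_idx is its end position.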
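A minimal sketch of what a "padded id list" input looks like. The toy vocabulary, the pad id 0 and the helper name to_padded_ids are hypothetical; the reference code's actual preprocessing is not shown in this post. The true lengths returned alongside the ids correspond to what the code later passes around as masks (question_lengths, passage_lengths).

```python
from typing import Dict, List, Tuple

PAD_ID = 0  # hypothetical padding id


def to_padded_ids(tokens_batch: List[List[str]], vocab: Dict[str, int],
                  max_len: int) -> Tuple[List[List[int]], List[int]]:
    """Map tokens to ids, truncate/pad to max_len, and return the true lengths."""
    ids_batch, lengths = [], []
    for tokens in tokens_batch:
        ids = [vocab.get(tok, vocab["<unk>"]) for tok in tokens][:max_len]
        lengths.append(len(ids))                            # length before padding
        ids_batch.append(ids + [PAD_ID] * (max_len - len(ids)))
    return ids_batch, lengths


vocab = {"<pad>": 0, "<unk>": 1, "the": 2, "cat": 3, "sat": 4, "who": 5}
ids, lengths = to_padded_ids([["the", "cat", "sat"], ["who", "sat"]], vocab, max_len=4)
print(ids)      # [[2, 3, 4, 0], [5, 4, 0, 0]]
print(lengths)  # [3, 2]
```

Padding makes every sequence in a batch the same length so they fit in one tensor; the separate lengths let the RNNs and the attention ignore the padded tail.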
The Match-LSTM model for QA consists of three main layers:
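- LSTM preprocessing layer.
The Passage and the Question are each run through an LSTM, so that the representation at every position carries some context information.
- Match-LSTM layer.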
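Match-LSTM was first used for textual entailment: given a premise and a hypothesis, decide whether the premise entails the hypothesis. For QA, the Question plays the role of the premise and the Passage the role of the hypothesis. The Passage is processed position by position: at each position, attention over the Question is computed, which yields an attend vector over the Question. This attend vector is concatenated with the first layer's output at that position and fed into an LSTM; this whole procedure is called Match-LSTM.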
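The attention used is BahdanauAttention. Its input (the Query) is the concatenation of the previous Match-LSTM output and the Passage representation at the current position; its keys are the representations at every Question position, and so are its values. The attend vector is the weighted sum of the attention values, with the attention alignments as the weights.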
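So Match-LSTM is essentially one LSTM cell plus one attention unit. The output of the LSTM cell is the output of the Match-LSTM layer. The LSTM cell's state is concatenated with the input at the next position to form the attention unit's input (Query), and the attention unit's output (the attend vector) is concatenated with the input at the current position to form the LSTM cell's input. Another way to see it: attention is added on top of an LSTM by changing the LSTM's input, augmenting the original input at each position with that position's attention over the Question.
- Pointer-Net layer.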
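In the code, the Pointer-Net layer is very close to Match-LSTM; it differs only in a few places involving inputs and outputs. Conceptually, the Pointer-Net layer is also a sequential, iterative attention process. It first uses the zero_state as the query to compute attention over all outputs of the Match-LSTM layer, which gives the logits for the first output symbol. The output of the LSTM cell inside the AttentionWrapper then serves as the query for the next step, again computing attention over all Match-LSTM outputs, and so on. For the boundary model only start_index and end_index need to be computed, so this iteration runs only twice.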
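The shapes flowing through the three layers can be written down as a small bookkeeping function. This is a sketch derived from the shape comments in the code later in this post ((-1, Q, H), (-1, P, 2*H), and logits of shape [batch_size, passage_length]); the function and argument names are hypothetical.

```python
from typing import Dict


def match_lstm_shapes(batch: int, q_len: int, p_len: int, hidden: int) -> Dict[str, object]:
    """Track tensor shapes through the three layers (shape bookkeeping only)."""
    return {
        # LSTM preprocessing: one unidirectional LSTM per input, hidden size H
        "encoded_question": (batch, q_len, hidden),
        "encoded_passage": (batch, p_len, hidden),
        # Match-LSTM: forward and backward outputs concatenated on the last axis
        "output_attender": (batch, p_len, 2 * hidden),
        # Pointer-Net: two decoding steps (start, end), each a row of scores per passage position
        "logits": [(batch, p_len), (batch, p_len)],
    }


shapes = match_lstm_shapes(batch=32, q_len=20, p_len=300, hidden=100)
print(shapes["output_attender"])  # (32, 300, 200)
```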
The following sections go through the implementation of the model's key structures, comparing the paper with the code.
Model
The entry point for building the model graph is the def setup_system(self) method of class QASystem in qa_model.py. This section essentially walks through that method in detail.
LSTM Preprocessing Layer
All of the logic lives in qa_model.py. The entry point is the def setup_system(self) method of class QASystem, and the actual logic is in the def encode(self, inputs, masks, encoder_state_input = None) method of class Encoder.
Inside def setup_system(self), the encode method of class Encoder is called with the following statements:
encoder = self.encoder
decoder = self.decoder
encoded_question, encoded_passage, q_rep, p_rep = encoder.encode([self.question, self.passage],
[self.question_lengths, self.passage_lengths], encoder_state_input = None)
Next, here is the implementation of encode:
def encode(self, inputs, masks, encoder_state_input = None):
    """
    :param inputs: vector representations of question and passage (a tuple)
    :param masks: masking sequences for both question and passage (a tuple)
    :param encoder_state_input: (Optional) pass this as initial hidden state to tf.nn.dynamic_rnn to build conditional representations
    :return: an encoded representation of the question and passage.
    """
    question, passage = inputs
    masks_question, masks_passage = masks
    # read passage conditioned upon the question
    with tf.variable_scope("encoded_question"):
        lstm_cell_question = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        encoded_question, (q_rep, _) = tf.nn.dynamic_rnn(lstm_cell_question, question, masks_question, dtype=tf.float32) # (-1, Q, H)
    with tf.variable_scope("encoded_passage"):
        lstm_cell_passage = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        encoded_passage, (p_rep, _) = tf.nn.dynamic_rnn(lstm_cell_passage, passage, masks_passage, dtype=tf.float32) # (-1, P, H)
    # outputs beyond sequence lengths are masked with 0s
    return encoded_question, encoded_passage , q_rep, p_rep
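As the code shows, preprocessing simply runs the Passage and the Question through two separate unidirectional LSTM layers (parameters are not shared); the LSTM output at each position is the preprocessed representation.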
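To make the comment "outputs beyond sequence lengths are masked with 0s" concrete, here is a minimal pure-Python sketch of what tf.nn.dynamic_rnn does with sequence_length for one sequence: it steps a BasicLSTMCell-style cell (gates i, j, f, o, forget_bias = 1.0) over the valid positions, emits zero outputs past the true length, and keeps the output of the last valid step. The toy weights, the weight layout and the helper names lstm_step and run_masked are assumptions for illustration, not the reference code's parameters.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x: Vector, h: Vector, c: Vector, W: Matrix, b: Vector,
              forget_bias: float = 1.0) -> Tuple[Vector, Vector]:
    """One BasicLSTMCell-style step. W has len(x)+len(h) rows and 4*H columns (gates i, j, f, o)."""
    H = len(h)
    xh = x + h  # concatenate current input and previous output
    z = [sum(xh[r] * W[r][k] for r in range(len(xh))) + b[k] for k in range(4 * H)]
    i, j, f, o = z[:H], z[H:2 * H], z[2 * H:3 * H], z[3 * H:]
    new_c = [c[k] * sigmoid(f[k] + forget_bias) + sigmoid(i[k]) * math.tanh(j[k]) for k in range(H)]
    new_h = [math.tanh(new_c[k]) * sigmoid(o[k]) for k in range(H)]
    return new_h, new_c


def run_masked(inputs: List[Vector], length: int, W: Matrix, b: Vector, H: int) -> Tuple[List[Vector], Vector]:
    """Mimic dynamic_rnn with sequence_length for one sequence: zero outputs after `length`."""
    h, c = [0.0] * H, [0.0] * H
    outputs = []
    for t, x in enumerate(inputs):
        if t < length:
            h, c = lstm_step(x, h, c, W, b)
            outputs.append(h)
        else:
            outputs.append([0.0] * H)  # padded position: zero output, state left unchanged
    return outputs, h  # h is the output at the last valid step


D, H = 2, 3
W = [[0.1 * (r + k) for k in range(4 * H)] for r in range(D + H)]
b = [0.0] * (4 * H)
padded = [[1.0, 0.5], [0.2, -0.3], [0.0, 0.0], [0.0, 0.0]]  # true length 2, then padding
outputs, final_h = run_masked(padded, length=2, W=W, b=b, H=H)
print(outputs[2], outputs[3])  # [0.0, 0.0, 0.0] [0.0, 0.0, 0.0]
print(final_h == outputs[1])   # True
```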
Match-LSTM Layer
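The Match-LSTM logic lives mainly in two files, qa_model.py and attention_wrapper.py. TensorFlow's contrib library now also has an attention_wrapper module, but the two implementations differ in the details. The entry point is the decode method of class Decoder in qa_model.py.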
First, the outermost entry point, which, as with the LSTM preprocessing layer, is in the def setup_system(self) method of class QASystem:
if self.config.use_match:
    self.logger.info("\n========Using Match LSTM=========\n")
    logits= decoder.decode([encoded_question, encoded_passage], q_rep, [self.question_lengths, self.passage_lengths], self.labels)
Next, step into the decode method of class Decoder. Its logic is very clear: first the Match-LSTM layer, then the Ptr-Net layer.
def decode(self, encoded_rep, q_rep, masks, labels):
    output_attender = self.run_match_lstm(encoded_rep, masks)
    logits = self.run_answer_ptr(output_attender, masks, labels)
    return logits
Then step into the run_match_lstm method:
def run_match_lstm(self, encoded_rep, masks):
    encoded_question, encoded_passage = encoded_rep
    masks_question, masks_passage = masks
    match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
    query_depth = encoded_question.get_shape()[-1]
    # output attention is false because we want to output the cell output and not the attention values
    with tf.variable_scope("match_lstm_attender"):
        attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
        cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
        lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
        # we don't mask the passage because masking the memories will be handled by the pointerNet
        reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
        output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
        output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
        output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
        output_attender = tf.concat([output_attender_fw, output_attender_bw], axis = -1) # (-1, P, 2*H)
    return output_attender
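Inputs of the method: encoded_rep is a tuple holding the Question and Passage representations, and masks is a tuple holding the Question and Passage lengths. The statement
match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
defines the input function of the AttentionMechanism inside the Match-LSTM cell; it is passed to the AttentionWrapper constructor as attention_input_fn. AttentionWrapper is itself an RNN cell: it combines an RNN cell and an AttentionMechanism into a higher-level RNN cell. This function defines how the Query for the attention mechanism is built: the input at the current time step is concatenated with the state from the previous time step. The statement
attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
defines an AttentionMechanism, that is, an attention unit. The class has a __call__ method with the signature def __call__(self, query, previous_alignments); calling the object computes the alignments. Read together with the previous statement, query here is exactly the attention Query described above.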
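For reference, the additive (Bahdanau) score of memory position j is score_j = v · tanh(W_m m_j + W_q q), and the alignments are softmax(scores) over the valid positions. Below is a minimal pure-Python sketch of one call that returns (alignments, raw_scores) in the same order as the self._attention_mechanism(...) call in the AttentionWrapper code. The weight names W_q, W_m, v, the helper names, and masking padded positions with -inf are assumptions of this sketch; TensorFlow's own BahdanauAttention precomputes the memory projection once, and the repo's variant may differ in details.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # rows = output dims, columns = input dims


def matvec(M: Matrix, x: Vector) -> Vector:
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def softmax(scores: Vector) -> Vector:
    peak = max(scores)  # finite as long as at least one position is valid
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def bahdanau_attention(query: Vector, memory: List[Vector], memory_length: int,
                       W_q: Matrix, W_m: Matrix, v: Vector) -> Tuple[Vector, Vector]:
    """Additive attention: score_j = v . tanh(W_m m_j + W_q q); padded positions masked."""
    q_proj = matvec(W_q, query)
    raw_scores = []
    for j, m in enumerate(memory):
        if j >= memory_length:
            raw_scores.append(float("-inf"))  # padded memory position: gets zero weight
            continue
        hidden = [math.tanh(a + b) for a, b in zip(matvec(W_m, m), q_proj)]
        raw_scores.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    return softmax(raw_scores), raw_scores


memory = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]  # third row is padding
identity = [[1.0, 0.0], [0.0, 1.0]]
alignments, raw = bahdanau_attention([0.5, -0.5], memory, memory_length=2,
                                     W_q=identity, W_m=identity, v=[1.0, 1.0])
print(alignments[2])  # 0.0
```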
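I won't go into the details of how BahdanauAttention is implemented here; a class of that name is now part of the TensorFlow library as tf.contrib.seq2seq.BahdanauAttention. The statement
cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
defines an ordinary LSTM cell.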
The statement
lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
assembles the AttentionMechanism and the LSTM cell defined in the previous two steps into a higher-level RNN cell. Its arguments also include the function defined at the beginning of run_match_lstm, which produces the query for the AttentionMechanism. The statements
reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
apply Match-LSTM to the Passage representation in the forward and in the backward direction; the two outputs are then concatenated along the last dimension to form the output of the Match-LSTM layer.
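The _reverse helper is not shown in this post. Its arguments (sequence lengths, sequence axis 1, batch axis 0) match the signature of tf.reverse_sequence, so a reasonable assumption is that it reverses only the first `length` positions of each sequence and leaves the padding in place. A pure-Python sketch of that assumed behavior (reverse_by_length is a hypothetical name), which also shows why the backward outputs are reversed a second time before concatenation:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def reverse_by_length(batch: List[List[T]], lengths: Sequence[int]) -> List[List[T]]:
    """Reverse the valid prefix of each sequence; the padded tail stays where it is.
    (Assumed to match tf.reverse_sequence with seq_axis=1, batch_axis=0.)"""
    return [seq[:n][::-1] + seq[n:] for seq, n in zip(batch, lengths)]


passage = [["p0", "p1", "p2", "<pad>"], ["q0", "q1", "<pad>", "<pad>"]]
lengths = [3, 2]
reversed_passage = reverse_by_length(passage, lengths)
print(reversed_passage)  # [['p2', 'p1', 'p0', '<pad>'], ['q1', 'q0', '<pad>', '<pad>']]

# A backward pass reads reversed_passage left to right; reversing its outputs again
# puts backward step t back under passage position t, aligned with the forward outputs.
assert reverse_by_length(reversed_passage, lengths) == passage
```

Reversing only the valid prefix matters: a plain full reversal would put the padding first, so the backward RNN would start on pad tokens and its outputs would no longer line up with the forward outputs by position.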
We can also take a closer look at how the LSTM cell and the AttentionMechanism work together. This means going into AttentionWrapper's call method, which every RNN cell has to implement.
def call(self, inputs, state):
    output_prev_step = state.cell_state.h # get hr_(i-1)
    attention_input = self._attention_input_fn(inputs, output_prev_step) # get input to BahdanauAttention to get alpha_i
    alignments, raw_scores = self._attention_mechanism(
        attention_input, previous_alignments=state.alignments)
    expanded_alignments = array_ops.expand_dims(alignments, 1)
    attention_mechanism_values = self._attention_mechanism.values
    context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
    context = array_ops.squeeze(context, [1])
    cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
    if self._attention_layer is not None:
        attention = self._attention_layer(
            array_ops.concat([cell_output, context], 1))
    else:
        attention = context
    if self._alignment_history:
        alignment_history = state.alignment_history.write(
            state.time, alignments)
    else:
        alignment_history = ()
    next_state = AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=alignments,
        alignment_history=alignment_history)
    if self._output_attention:
        return raw_scores, next_state
    else:
        return cell_output, next_state
The statements
output_prev_step = state.cell_state.h # get hr_(i-1)
attention_input = self._attention_input_fn(inputs, output_prev_step)
take the LSTM cell's output h from the previous time step and the AttentionWrapper's input at the current time step, and build the attention Query with self._attention_input_fn. Here self._attention_input_fn is the attention_input_fn argument of the AttentionWrapper constructor shown above. The statement
alignments, raw_scores = self._attention_mechanism(attention_input, previous_alignments=state.alignments)
calls the AttentionMechanism object to compute the attention alignments. Here self._attention_mechanism is the attention_mechanism_match_lstm argument of the AttentionWrapper constructor, that is, an instance of BahdanauAttention. The statements
expanded_alignments = array_ops.expand_dims(alignments, 1) # [batch_size, 1, ques_size]
attention_mechanism_values = self._attention_mechanism.values # [batch_size, ques_size, value_dims]
context = math_ops.matmul(expanded_alignments, attention_mechanism_values) # [batch_size, 1, value_dims]
context = array_ops.squeeze(context, [1]) # [batch_size, value_dims]
compute the attend vector from the alignments and the attention values: it is the sum of the values weighted by the alignments.
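The expand_dims, matmul, squeeze sequence is just a batched weighted sum: for each example, context = Σ_j alignments[j] · values[j]. A pure-Python sketch using the shapes from the comments above (the function name attend is hypothetical):

```python
from typing import List

Vector = List[float]


def attend(alignments: List[Vector], values: List[List[Vector]]) -> List[Vector]:
    """alignments: [batch, ques_size]; values: [batch, ques_size, value_dims] -> [batch, value_dims]."""
    contexts = []
    for weights, rows in zip(alignments, values):
        dims = len(rows[0])
        # Same as multiplying a [1, ques_size] row by a [ques_size, value_dims] matrix, then squeezing.
        contexts.append([sum(w * row[d] for w, row in zip(weights, rows)) for d in range(dims)])
    return contexts


alignments = [[0.25, 0.75, 0.0]]                 # one example, three question positions (last one padded)
values = [[[1.0, 0.0], [0.0, 2.0], [5.0, 5.0]]]
print(attend(alignments, values))  # [[0.25, 1.5]]
```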
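The statements
cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
use _cell_input_fn to combine the current input with the attend vector into the LSTM's input for this time step, and then call the LSTM cell to compute its output and state for this step. The statements
if self._attention_layer is not None:
    attention = self._attention_layer(
        array_ops.concat([cell_output, context], 1))
else:
    attention = context
decide whether the attend vector (concatenated with the cell output) goes through one more linear transformation before being used as the attention. In this example no transformation is applied: the attend vector is used directly as the attention.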
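The statement
next_state = AttentionWrapperState(time=state.time + 1, cell_state=next_cell_state, attention=attention, alignments=alignments, alignment_history=alignment_history)
builds the state of the AttentionWrapper, viewed as an RNN, for the next time step. Finally,
if self._output_attention:
    return raw_scores, next_state
else:
    return cell_output, next_state
decides, according to the constructor argument, whether the AttentionWrapper outputs the attention scores or the LSTM output. The attention scores (raw_scores) are the unnormalized scores that are turned into the alignment probabilities.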
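Putting these steps together, one step of the wrapper can be sketched with plain functions. The cell and the attention are passed in as callables, so the sketch only captures the data flow of call: the query from attention_input_fn, the context as the alignment-weighted sum of the values, the cell input from cell_input_fn, and the output_attention switch. The name wrapper_step and the stub cell and stub attention in the usage are hypothetical stand-ins, not an LSTM or BahdanauAttention.

```python
from typing import Callable, List, Tuple

Vector = List[float]
# attention(query) -> (alignments, raw_scores); cell(inputs, state) -> (output, new_state)
Attention = Callable[[Vector], Tuple[Vector, Vector]]
Cell = Callable[[Vector, Vector], Tuple[Vector, Vector]]


def wrapper_step(inputs: Vector, prev_output: Vector, cell_state: Vector,
                 attention: Attention, values: List[Vector], cell: Cell,
                 attention_input_fn: Callable[[Vector, Vector], Vector],
                 cell_input_fn: Callable[[Vector, Vector], Vector],
                 output_attention: bool) -> Tuple[Vector, Vector, Vector]:
    """Data flow of one AttentionWrapper.call step; returns (output, cell_output, new_cell_state)."""
    query = attention_input_fn(inputs, prev_output)                 # Query for the attention unit
    alignments, raw_scores = attention(query)
    context = [sum(a * v[d] for a, v in zip(alignments, values))    # attend vector
               for d in range(len(values[0]))]
    cell_output, new_state = cell(cell_input_fn(inputs, context), cell_state)
    output = raw_scores if output_attention else cell_output
    return output, cell_output, new_state


values = [[1.0, 0.0], [0.0, 1.0]]
uniform_attention = lambda query: ([0.5, 0.5], [0.0, 0.0])   # stub: ignores the query
sum_cell = lambda x, state: ([sum(x)], [sum(x) + state[0]])   # stub: not an LSTM
match_concat = lambda inp, prev: inp + prev                   # Match-LSTM style attention_input_fn
concat_inputs = lambda inp, ctx: inp + ctx                    # default-style cell_input_fn

out, cell_out, new_state = wrapper_step([3.0], [0.0], [0.0], uniform_attention, values, sum_cell,
                                        match_concat, concat_inputs, output_attention=False)
print(out)  # [4.0]
```

With output_attention=True the same call would return the stub's raw scores, [0.0, 0.0], instead of the cell output, which is the switch the Pointer-Net layer relies on.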
Pointer-Net Layer
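The code below is the logic of the Pointer-Net layer. It is very close to the Match-LSTM layer but differs in some details. What they share: the body of the Pointer-Net layer is also an AttentionWrapper that assembles an LSTM cell and a BahdanauAttention unit. Where they differ: the input functions of the LSTM cell and of the BahdanauAttention unit are different, the AttentionWrapper outputs something different, and the Pointer-Net layer uses a static RNN (tf.nn.static_rnn).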
def run_answer_ptr(self, output_attender, masks, labels):
    batch_size = tf.shape(output_attender)[0]
    masks_question, masks_passage = masks
    labels = tf.unstack(labels, axis=1)
    #labels = tf.ones([batch_size, 2, 1])
    answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
    query_depth_answer_ptr = output_attender.get_shape()[-1]
    with tf.variable_scope("answer_ptr_attender"):
        attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
        # output attention is true because we want to output the attention values
        cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
        answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
        logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)
    return logits
Let's go through this code piece by piece.
batch_size = tf.shape(output_attender)[0] # [batch_size, passage_length, 2 * hidden_size]
masks_question, masks_passage = masks
labels = tf.unstack(labels, axis=1) # labels : [batch_size, 2]
output_attender is the output of the previous layer, the Match-LSTM layer, with shape [batch_size, passage_length, 2 * hidden_size]. labels has shape [batch_size, 2]. masks_question and masks_passage are the question lengths and the passage lengths respectively. Next:
answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
query_depth_answer_ptr = output_attender.get_shape()[-1]
answer_ptr_cell_input_fn defines the input function of the LSTM cell inside the AttentionWrapper. Judging by its name, query_depth_answer_ptr is the query depth of the Answer-Ptr layer's attention unit. Next comes the assembly of the AttentionWrapper:
with tf.variable_scope("answer_ptr_attender"):
    attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
    # output attention is true because we want to output the attention values
    cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
    answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
This is where the Pointer-Net layer differs from the Match-LSTM layer. The Match-LSTM layer did not pass cell_input_fn to AttentionWrapper explicitly and used the default function, whereas here it is passed explicitly. Conversely, the Match-LSTM layer passed attention_input_fn explicitly, while here it is omitted and the default function is used. One more difference: in the Match-LSTM layer the output_attention argument of AttentionWrapper is False, whereas here it keeps its default value, True.
Comparing cell_input_fn in the Match-LSTM layer and the Pointer-Net layer.
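The default cell_input_fn, which the Match-LSTM layer uses, is defined as follows. It concatenates the current input with the attention output and uses the result as the LSTM cell's input.
if cell_input_fn is None:
    cell_input_fn = (
        lambda inputs, attention: array_ops.concat([inputs, attention], -1))
The cell_input_fn used by the Pointer-Net layer was already shown in the code above; here it is again for comparison. It uses only the attention unit's output as the LSTM cell's input, so the LSTM cell's input no longer depends on the RNN's input.
answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question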
Comparing attention_input_fn in the Match-LSTM layer and the Pointer-Net layer.
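The attention_input_fn of the Match-LSTM layer is a non-default one, already given in the previous section; here it is again for comparison.
match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
The Pointer-Net layer uses the default attention_input_fn, defined as follows.
if attention_input_fn is None:
    attention_input_fn = (
        lambda _, state: state)
So in the Match-LSTM layer, the input of the attention unit is the concatenation of the previous state and the current input, while in the Pointer-Net layer it is only the previous state and does not depend on the current input.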
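Putting these two points together: in the Match-LSTM layer, the inputs of both the attention unit and the LSTM cell include the current input, whereas in the Pointer-Net layer neither depends on the current input. This also answers a question I had when I first read the code: why does the function that computes the logits take labels as an argument, when labels should only be needed for the loss? Although labels is passed in, its contents are never actually used; only its shape matters (the unstacked list has two elements, so static_rnn runs two steps). For prediction, any tensor of the same shape will do.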
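The claim that the label contents are never read can be checked on the two pairs of input functions themselves. In this sketch, Python list concatenation stands in for tf.concat along the last axis (an assumption of the sketch), and the variable names are illustrative; only the lambda bodies mirror the code quoted above.

```python
from typing import List

Vector = List[float]

# Match-LSTM layer: both input functions use the current input.
match_attention_input_fn = lambda curr_input, state: curr_input + state
match_cell_input_fn = lambda inputs, attention: inputs + attention

# Pointer-Net layer: both input functions ignore the current input.
ptr_attention_input_fn = lambda _, state: state
ptr_cell_input_fn = lambda curr_input, context: context

state: Vector = [0.1, 0.2]
context: Vector = [0.3]
for dummy_input in ([7.0], [0.0]):  # e.g. a real label vs. an all-zero placeholder of the same shape
    print(ptr_attention_input_fn(dummy_input, state), ptr_cell_input_fn(dummy_input, context))
# Both iterations print [0.1, 0.2] [0.3]; the Match-LSTM functions would give different results.
```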
The last difference to compare is the output_attention argument of the two layers.
if self._output_attention:
    return raw_scores, next_state
else:
    return cell_output, next_state
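raw_scores is the raw output of the attention unit, that is, the scores before softmax turns them into alignments. cell_output is the output of the LSTM cell, i.e. its state h. In the Match-LSTM layer, the AttentionWrapper outputs the output of its inner LSTM cell; in the Pointer-Net layer, it outputs the raw_scores of its inner attention unit.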
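Finally, the logits are computed:
logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)
Because labels is a list of length 2, logits is also a list of length 2. The elements of the two lists have different shapes, however: each element of labels has shape [batch_size, 1], while each element of logits has shape [batch_size, passage_length].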
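In terms of the code: attention is first computed with the zero_state as the query. The attention unit's keys and values are both the outputs of the Match-LSTM layer, and the raw scores it computes are the first logits. The alignments from that attention are combined with the values into an attend vector, which is fed to the LSTM cell; the cell's output becomes the query for the next step's attention. This yields the two logits.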
At this point the logits are computed and the prediction part is complete. logits is a list of length 2, and each element is a tensor of shape [batch_size, passage_length].
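The two-step loop described above can be sketched as follows for a single example: start from a zero query, score every Match-LSTM output, emit the raw scores as that step's logits, then turn the attend vector into the next query. The weight-free additive score and the tanh "cell" are hypothetical simplifications standing in for BahdanauAttention and BasicLSTMCell (in the real model the query is an H-dimensional LSTM output and the memory is 2H-dimensional, reconciled by learned projections); masking padded positions with -inf is also an assumption of this sketch.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pointer_decode(memory: List[Vector], length: int, steps: int = 2) -> List[Vector]:
    """Return one row of logits over passage positions per decoding step (start, end)."""
    dims = len(memory[0])
    query = [0.0] * dims  # zero_state as the first query
    all_logits = []
    for _ in range(steps):
        # toy additive score (stand-in for BahdanauAttention); padded positions masked out
        logits = [sum(math.tanh(m + q) for m, q in zip(mem, query)) if j < length else float("-inf")
                  for j, mem in enumerate(memory)]
        all_logits.append(logits)
        alignments = softmax(logits)
        context = [sum(a * mem[d] for a, mem in zip(alignments, memory)) for d in range(dims)]
        query = [math.tanh(c) for c in context]  # toy "cell": its output is the next query
    return all_logits


memory = [[0.2, -0.1], [1.0, 0.8], [-0.5, 0.3], [0.0, 0.0]]  # Match-LSTM outputs; last row is padding
start_logits, end_logits = pointer_decode(memory, length=3)
print(len(start_logits), len(end_logits))           # 4 4
print(max(range(3), key=start_logits.__getitem__))  # 1
```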
Loss Function
With the logits in hand, the loss can be computed.
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[0], labels=self.labels[:,0])
losses += tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[1], labels=self.labels[:,1])
self.loss = tf.reduce_mean(losses)
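Only one function needs to be understood here: tf.nn.sparse_softmax_cross_entropy_with_logits. Its logits argument has rank one higher than labels, and the extra axis has size num_classes. labels is in sparse form: every element is an integer smaller than num_classes.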
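As established earlier, the logits produced by the Pointer-Net layer are a list whose elements each have shape [batch_size, passage_length], while the input labels have shape [batch_size, 2]. Calling the function as in the code above therefore computes the loss.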
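A pure-Python sketch of the same computation. Sparse softmax cross-entropy takes, for each example, a row of num_classes logits (here num_classes = passage_length) and one integer label, and returns -log softmax(logits)[label]; the boundary loss adds the start and end terms and averages over the batch, as in the code above. The function names are illustrative, not TensorFlow APIs.

```python
import math
from typing import List, Sequence

Vector = List[float]


def sparse_softmax_cross_entropy(logits: List[Vector], labels: Sequence[int]) -> List[float]:
    """Per-example -log softmax(logits)[label], using the log-sum-exp trick for stability."""
    losses = []
    for row, label in zip(logits, labels):
        peak = max(row)
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in row))
        losses.append(log_sum_exp - row[label])
    return losses


def boundary_loss(logits: List[List[Vector]], labels: List[List[int]]) -> float:
    """logits: [start_logits, end_logits], each [batch, passage_len]; labels: [batch, 2]."""
    start = sparse_softmax_cross_entropy(logits[0], [lab[0] for lab in labels])
    end = sparse_softmax_cross_entropy(logits[1], [lab[1] for lab in labels])
    per_example = [s + e for s, e in zip(start, end)]
    return sum(per_example) / len(per_example)  # corresponds to tf.reduce_mean


# Uniform logits over 4 positions: each term is log(4), so the loss is 2 * log(4).
uniform = [[0.0, 0.0, 0.0, 0.0]]
print(round(boundary_loss([uniform, uniform], [[1, 2]]), 4))  # 2.7726
```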
