QA系統Match-LSTM代碼研讀
背景
在QA模型中,Match-LSTM是較早提出的,使用Prt-Net邊界模型。本文是對閱讀其實現代碼的總結。主要思路是對照着論文和代碼,對論文中模型的關鍵結構,查看代碼中的具體實現。參考代碼是MurtyShikhar實現的。
模型簡介
模型的輸入是(Passage, Question),模型的輸出是(start_idx, end_idx)。對於輸入,Passage是QA任務中的正文,輸入給模型時已經轉化為經過Padding的id-list;Question是QA任務中的問題,輸入給模型時已經轉化為經過Padding的id-list。對於輸出,start_idx是答案在正文的起始位置,end_idx是答案在正文的結束位置。
用於QA的Match-LSTM模型主要由三層構成:
- LSTM預處理層。
分別將Passage和Question通過LSTM進行處理,使每個位置的表示都帶有一些上下文信息。- Match-LSTM層。
Match-LSTM最早用於文本蘊含,輸入一個前提,一個猜測,判斷前提是否能蘊含猜測。在用於QA任務時,Question被當做前提,Passage被當做猜測。依次處理Passage的每個位置,計算Passage每個位置對Question的Attention,進而求出對Question的Attend Vector。該Attend Vector與第一層的輸出拼接起來,輸入給一個LSTM進行處理,這整個流程被稱作Match-LSTM。
其中Attention選擇BahdanauAttention,Attention的輸入(Query)由上一時刻Match-LSTM的輸出及Passage在當前位置的表示拼接,Attention的key是Question每個位置的表示,Attention的value也是Question每個位置的表示。根據Attention的alignment對Attention Value加權求和計算出Attend Vector。
所以,Match-LSTM本質上由一個LSTM單元和一個Attention單元組成。LSTM單元的輸出作為Match-LSTM層的輸出,LSTM單元的狀態和下一個位置的輸入拼接起來作為Attention單元的輸入(Query),Attention單元的輸出(Attend Vector)與當前位置的輸入拼接起來作為LSTM單元的輸入。也可以理解為在LSTM的基礎上增加Attention,改變LSTM的輸入,在LSTM的原始輸入上增加當前位置對於Question的Attention。- Pointer-Net層。
Pointer-Net層在代碼實現上,與Match-LSTM十分接近。只在涉及輸入、輸出的地方有幾處不同。從原理上看,Pointer-Net層也是一個序列化迭代的Attention過程,首先用zero_state作為query對Match-LSTM層的所有輸出計算attention,作為回答第一個符號的logit。然后以AttentionWrapper的輸出作為下一時刻的query,對Match-LSTM層的所有輸出計算attention,如此迭代進行。對於邊界模型,秩序計算start_index和end_index,這個迭代過程秩序進行兩次。
接下來的幾部分對照論文及代碼中模型關鍵結構實現。
模型
模型圖構建的入口在qa_model.py文件中class QASystem
類的def setup_system(self)
方法內。這一節主要就是對該方法的細節展開解讀。
LSTM預處理層
所有邏輯都包含在qa_model.py文件中,入口位於class QASystem
類的def setup_system(self)
方法內,具體邏輯位於class Encoder
的def encode(self, inputs, masks, encoder_state_input = None)
方法內。
在def setup_system(self)
方法內,通過以下語句調用class Encoder
的def encode(self, inputs, masks, encoder_state_input = None)
方法。
encoder = self.encoder
decoder = self.decoder
encoded_question, encoded_passage, q_rep, p_rep = encoder.encode([self.question, self.passage],
[self.question_lengths, self.passage_lengths], encoder_state_input = None)
再看一下encode
方法的實現。
def encode(self, inputs, masks, encoder_state_input = None):
"""
:param inputs: vector representations of question and passage (a tuple)
:param masks: masking sequences for both question and passage (a tuple)
:param encoder_state_input: (Optional) pass this as initial hidden state to tf.nn.dynamic_rnn to build conditional representations
:return: an encoded representation of the question and passage.
"""
question, passage = inputs
masks_question, masks_passage = masks
# read passage conditioned upon the question
with tf.variable_scope("encoded_question"):
lstm_cell_question = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
encoded_question, (q_rep, _) = tf.nn.dynamic_rnn(lstm_cell_question, question, masks_question, dtype=tf.float32) # (-1,
Q, H)
with tf.variable_scope("encoded_passage"):
lstm_cell_passage = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
encoded_passage, (p_rep, _) = tf.nn.dynamic_rnn(lstm_cell_passage, passage, masks_passage, dtype=tf.float32) # (-1, P,
H)
# outputs beyond sequence lengths are masked with 0s
return encoded_question, encoded_passage , q_rep, p_rep
從代碼可以看出,對Passage和Question的預處理就是分別經過兩個單向LSTM層(不共享參數),LSTM每個位置的輸出作為預處理后的表示。
Match-LSTM層
Match-LSTM的邏輯主要在qa_model.py和attention_wrapper.py兩個文件中。雖然tensorflow的contrib庫中現在也有attention_wrapper這個模塊,但是兩者在具體實現上不太相同。入口位於qa_model.py文件class Decoder
類中decode
方法內。
首先,看一下最外層的入口,與LSTM預處理層一樣,位於class QASystem
類的def setup_system(self)
方法內。
if self.config.use_match:
self.logger.info("\n========Using Match LSTM=========\n")
logits= decoder.decode([encoded_question, encoded_passage], q_rep, [self.question_lengths, self.passage_lengths], self.
labels)
接下來,進入class Decoder
類中decode
方法。函數邏輯非常清晰,先通過Match-LSTM層,再通過Ptr-Net層。
def decode(self, encoded_rep, q_rep, masks, labels):
output_attender = self.run_match_lstm(encoded_rep, masks)
logits = self.run_answer_ptr(output_attender, masks, labels)
return logits
然后進入run_match_lstm
方法。
def run_match_lstm(self, encoded_rep, masks):
encoded_question, encoded_passage = encoded_rep
masks_question, masks_passage = masks
match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
query_depth = encoded_question.get_shape()[-1]
# output attention is false because we want to output the cell output and not the attention values
with tf.variable_scope("match_lstm_attender"):
attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
# we don't mask the passage because masking the memories will be handled by the pointerNet
reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0)
output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn")
output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn")
output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
output_attender = tf.concat([output_attender_fw, output_attender_bw], axis = -1) # (-1, P, 2*H)
return output_attender
該方法的輸入
encoded_rep
是一個tuple
,包含Passage和Question的表示;masks
也是一個tuple
,包含Passage和Question的長度。match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
這條語句定義了
Match-LSTM
單元中AttentionMechanism
的輸入函數,作為參數該函數被傳遞給AttentionWrapper
的構造函數,作為attention_input_fn
。AttentionWrapper
本身也是一個RNN
,它組合了一個RNN
和一個AttentionMechanism
,形成一個高級的RNN
單元。該函數就是定義了用於Attention機制的Query是如何生成的,由當前時刻的輸入拼接上一個時刻的state,形成Attention的Query。attention_mechanism_match_lstm = BahdanauAttention(query_depth, encoded_question, memory_sequence_length = masks_question)
這條語句定義了一個
AttentionMechanism
,也就是一個Attention單元,該類包含一個__call__
方法,調用該對象可以計算出alignments
,調用該類對象的參數如方法定義所示def __call__(self, query, previous_alignments)
。聯系上面一起來看,這里的query
就是上面所說的Attention的Query。
至於BahdanauAttention
是如何實現的,暫時不做過詳細的介紹,目前該類位於tf.contrib.seq2seq.BahdanauAttention
,已經是tensorflow庫的一部分。cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True)
這條語句定義一個普通的LSTM單元。
lstm_attender = AttentionWrapper(cell, attention_mechanism_match_lstm, output_attention = False, attention_input_fn = match_lstm_cell_attention_fn)
這條語句將上面兩步定義的
AttentionMechanism
及LSTM單元組裝為一個高級RNN單元。參數還包括了在run_match_lstm
方法一開頭頂一個的一個函數,該函數用來生成AttentionMechanism
的query
。reverse_encoded_passage = _reverse(encoded_passage, masks_passage, 1, 0) output_attender_fw, _ = tf.nn.dynamic_rnn(lstm_attender, encoded_passage, dtype=tf.float32, scope ="rnn") output_attender_bw, _ = tf.nn.dynamic_rnn(lstm_attender, reverse_encoded_passage, dtype=tf.float32, scope = "rnn") output_attender_bw = _reverse(output_attender_bw, masks_passage, 1, 0)
分別正向、反向對Passage的表示應用Match-LSTM,再將輸出沿最后一個維度拼接起來作為Match-LSTM層的輸出。
我們還可以再近距離看一下LSTM單元和AttentionMechanism
是如何配合工作的,這需要深入到AttentionWrapper
的call
方法,這也是所有RNN單元都需要實現的一個方法。
def call(self, inputs, state):
output_prev_step = state.cell_state.h # get hr_(i-1)
attention_input = self._attention_input_fn(inputs, output_prev_step) # get input to BahdanauAttention to get alpha_i
alignments, raw_scores = self._attention_mechanism(
attention_input, previous_alignments=state.alignments)
expanded_alignments = array_ops.expand_dims(alignments, 1)
attention_mechanism_values = self._attention_mechanism.values
context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
context = array_ops.squeeze(context, [1])
cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
if self._attention_layer is not None:
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
else:
attention = context
if self._alignment_history:
alignment_history = state.alignment_history.write(
state.time, alignments)
else:
alignment_history = ()
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignments=alignments,
alignment_history=alignment_history)
if self._output_attention:
return raw_scores, next_state
else:
return cell_output, next_state
output_prev_step = state.cell_state.h # get hr_(i-1) attention_input = self._attention_input_fn(inputs, output_prev_step)
取LSTM單元上一時刻的狀態,與
AttentionWrapper
當前時刻的輸入,通過self._attention_input_fn
函數生成attention的Query。這里的self._attention_input_fn
就是上面AttentionWrapper
構造函數的參數attention_input_fn
。alignments, raw_scores = self._attention_mechanism(attention_input, previous_alignments=state.alignments)
調用
AttentionMechaism
對象,計算Attention的alignments。這里的self._attention_mechanism
就是AttentionWrapper
構造函數的參數attention_mechanism_match_lstm
,也就是BahdanauAttention
的一個對象。expanded_alignments = array_ops.expand_dims(alignments, 1) # [batch_size, 1, ques_size] attention_mechanism_values = self._attention_mechanism.values # [batch_size, ques_size, value_dims] context = math_ops.matmul(expanded_alignments, attention_mechanism_values) # [batch_size, 1, value_dims] context = array_ops.squeeze(context, [1]) # [batch_size, value_dims]
通過alignments和attention的Values,計算attend vector,就是對values以alignments為權重求和。
cell_inputs = self._cell_input_fn(inputs, context) #concatenate input with alpha*memory and feed into root LSTM cell_state = state.cell_state cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
通過
_cell_input_fn
將當前時刻的輸入,和attend vector組合起來,成為當前時刻LSTM的輸入。然后調用LSTM單元計算當前時刻LSTM單元的輸出和狀態。if self._attention_layer is not None: attention = self._attention_layer( array_ops.concat([cell_output, context], 1)) else: attention = context
是否需要對attend vector再進行一次線性變換,作為attention,在本例中未做變換,直接用attend vector作為attention。
next_state = AttentionWrapperState( time=state.time + 1, cell_state=next_cell_state, attention=attention, alignments=alignments, alignment_history=alignment_history)
作為RNN的
AttentionWrapper
的下一時刻狀態。if self._output_attention: return raw_scores, next_state else: return cell_output, next_state
根據構造函數的參數,決定
AttentionWrapper
的輸出是attention score還是LSTM的輸出,attention score的意義是求alignments概率之前的那個東西。
Pointer-Net層
以下代碼是Pointer-Net層的邏輯,與Match-LSTM層的邏輯非常接近,但是在一些細節上有所區別。相似的部分是,Pointer-Net層的主體也是通過一個AttentionWrapper
完成的,也是組裝了一個LSTM
單元和一個BahdanauAttention
單元。與Match-LSTM不同的地方是,LSTM
單元及BahdanauAttention
單元的輸入函數不同,AttentionWrapper
的輸出內容不同,並且Pointer-Net層使用一個靜態rnn。
def run_answer_ptr(self, output_attender, masks, labels):
batch_size = tf.shape(output_attender)[0]
masks_question, masks_passage = masks
labels = tf.unstack(labels, axis=1)
#labels = tf.ones([batch_size, 2, 1])
answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
query_depth_answer_ptr = output_attender.get_shape()[-1]
with tf.variable_scope("answer_ptr_attender"):
attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage)
# output attention is true because we want to output the attention values
cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True )
answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)
return logits
接下來具體看一下上面這段代碼。
batch_size = tf.shape(output_attender)[0] # [batch_size, passage_length, 2 * hidden_size] masks_question, masks_passage = masks labels = tf.unstack(labels, axis=1) # labels : [batch_size, 2]
output_attender
是上一層,也就是Match-LSTM層的輸出,形狀為[batch_size, passage_length, 2 * hidden_size]
。labels
的形狀為[batch_size, 2]
。masks_question
和masks_passage
分別為問題的長度和文章的長度。answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question query_depth_answer_ptr = output_attender.get_shape()[-1]
answer_ptr_cell_input_fn
定義了AttentionWrapper
中LSTM
單元的輸入函數。query_depth_answer_ptr
從變量名的字面含義看,是Answer-Ptr層的attention單元的query的維度。with tf.variable_scope("answer_ptr_attender"): attention_mechanism_answer_ptr = BahdanauAttention(query_depth_answer_ptr , output_attender, memory_sequence_length = masks_passage) # output attention is true because we want to output the attention values cell_answer_ptr = tf.contrib.rnn.BasicLSTMCell(self.hidden_size, state_is_tuple = True ) answer_ptr_attender = AttentionWrapper(cell_answer_ptr, attention_mechanism_answer_ptr, cell_input_fn = answer_ptr_cell_input_fn)
接下來是裝配
AttentionWrapper
,這里與Match-LSTM層有區別。在Match-LSTM層的定義中,沒有顯式地為AttentionWrapper
指定cell_input_fn
參數,而是使用了默認函數。在Match-LSTM層的定義中,顯式指定了attention_input_fn
,但是這里沒有指定,使用了默認函數。另外一個區別,在Match-LSTM層的定義中,AttentionWrapper
的output_attention
參數是False
,在這里該參數用默認的True
。
對比Match-LSTM層與Pointer-Net層cell_input_fn
的區別。
默認的
cell_input_fn
的定義如下,這是Match-LSTM層采用的。邏輯是將attention的輸出和當前的輸入拼接起來,作為LSTM
單元的輸入。if cell_input_fn is None: cell_input_fn = ( lambda inputs, attention: array_ops.concat([inputs, attention], -1))
Pointer-Net層使用的
cell_input_fn
在上面的代碼中已經給出,這里對比一下。只用Attention單元的輸出,作為LSTM
單元的輸入。這樣,LSTM
單元的輸入,就與RNN的輸入無關了。answer_ptr_cell_input_fn = lambda curr_input, context : context # independent of question
對比Match-LSTM層與Pointer-Net層attention_input_fn
的區別。
Match-LSTM層采用的
attention_input_fn
是非默認的,在上一節中已經給出,這里對比一下。match_lstm_cell_attention_fn = lambda curr_input, state : tf.concat([curr_input, state], axis = -1)
Pointer-Net層的
attention_input_fn
是默認的,定義如下。if attention_input_fn is None: attention_input_fn = ( lambda _, state: state)
可以看出,在Match-LSTM層,
attention
單元的輸入是上一時刻狀態與當前輸入的拼接。在Pointer-Net層,attention
單元的輸入僅僅是上一時刻的狀態,與當前時刻的輸入無關。
綜上兩處,可以看出區別。在Match-LSTM層,無論Attention
單元還是LSTM
單元,其輸入都要拼接當前時刻輸入。而在Pointer-Net層,無論Attention
單元還是LSTM
單元,其輸入都與當前時刻的輸入無關。這也解釋了我最早看代碼時的疑惑,為什么計算logits
的函數需要labels
作為參數,labels
不是只有在計算loss
的時候才需要嗎?其實雖然這里有labels
這個參數,但是沒有實際使用其內容,對於預測過程,只需傳一個同樣形狀的tensor就可以。
再對比最后一個區別,Match-LSTM層與Pointer-Net層在output_attention參數上的區別。
if self._output_attention: return raw_scores, next_state else: return cell_output, next_state
raw_scores
是attention
單元的原始輸出,即通過softmax
計算alignments
前的那個輸出。cell_output
是LSTM
單元的輸出,也就是狀態h
。在Match-LSTM層,AttentionWrapper
輸出的是其內部LSTM
單元的輸出。在Pointer-Net層,AttentionWrapper
輸出的是其內部attention
單元的raw_scores
。
logits, _ = tf.nn.static_rnn(answer_ptr_attender, labels, dtype = tf.float32)
最后是計算
logits
。因為labels
是個長度為2的list
,logits
也是長度為2的list
。但是,這兩個list
中元素的shape
是不一樣的,labels
中的元素的shape
是[batch_size, 1]
,logits
中的元素的shape
是[batch_size, passage_length]
。
從代碼層面來理解,首先是以zero_state
為query去計算attention,attention
單元的key和value都是Match-LSTM層的輸出,attention
計算的raw_score
就是第一個輸出的logit
。attention
計算出的alignments
與values
計算attend vector
,以其為輸入計算LSTM
單元的輸出,作為下一時刻的query去計算attention。這樣,就計算出了兩個logits
。
至此,計算出logits
,預測部分就已經完成了。logits
是一個長度為2的list
,其中每個元素是一個shape
為[batch_size, passage_length]
的tensor。
損失函數
有了logits
,就可以計算損失函數了。
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[0], labels=self.labels[:,0])
losses += tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits[1], labels=self.labels[:,1])
self.loss = tf.reduce_mean(losses)
這里只需要理解一個函數即可tf.nn.sparse_softmax_cross_entropy_with_logits
,該函數logits
參數的rank
比labels
多1,多出的那個axis
的維度是num_classes
。labels
以稀疏形式表示,每個元素都是整數,小於num_classes
。
由於之前已經知道,Pointer-Net層求出logits
是一個list
,每個元素的形狀是[batch_size, passage_length]
,而輸入的labels
的形狀是[batch_size, 2]
。因此按照上面代碼的方式調用可求出損失函數。