Bert源碼解讀(三)之預訓練部分


一、Masked LM

get_masked_lm_output函數用於計算「任務#1」的訓練 loss。輸入為 BertModel 的最后一層 sequence_output 輸出([batch_size, seq_length, hidden_size]),先找出輸出結果中masked掉的詞,然后構建一層全連接網絡,接着構建一層節點數為vocab_size的softmax輸出,從而與真實label計算損失。

def get_masked_lm_output(bert_config, 
                        input_tensor, #BertModel的最后一層sequence_output輸出model.get_sequence_output()[batch_size, seq_length, hidden_size]
                        output_weights,#輸入是model.get_embedding_table(),[vocab_size,hidden_size]
                           positions, #mask詞的位置
                         label_ids, #label,真實值結果
                         label_weights):
                         
  """Get loss and log probs for the masked LM."""
  # 根據positions位置獲取masked詞在Transformer的輸出結果,即要預測的那些位置的encoder
  input_tensor = gather_indexes(input_tensor, positions)#[batch_size*max_pred_pre_seq,hidden_size]

  with tf.variable_scope("cls/predictions"):
    # 在輸出之前添加一個帶激活函數的全連接神經網絡,只在預訓練階段起作用
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=bert_config.hidden_size,
          activation=modeling.get_activation(bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              bert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # output_weights是和傳入的word embedding一樣的,這里再添加一個bias
    output_bias = tf.get_variable(
        "output_bias",
        shape=[bert_config.vocab_size],
        initializer=tf.zeros_initializer())
        
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size*max_pred_pre_seq,vocab_size]
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)#得出masked詞的softmax結果,[batch_size*max_pred_pre_seq,vocab_size]

    # label_ids表示mask掉的Token的id,下面這部分就是根據真實值計算loss了。
    label_ids = tf.reshape(label_ids, [-1])#[batch_size*max_pred_per_seq] 
    label_weights = tf.reshape(label_weights, [-1])

    one_hot_labels = tf.one_hot(
        label_ids, depth=bert_config.vocab_size, dtype=tf.float32)#[batch_size*max_pred_per_seq,vocab_size]

    # 但是由於實際MASK的可能不到20,比如只MASK18,那么label_ids有2個0(padding),而label_weights=[1, 1, ...., 0, 0],說明后面兩個label_id是padding的,計算loss要去掉,label_weights就是起一個標記作用
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])#[batch_size*max_pred_per_seq] 
    numerator = tf.reduce_sum(label_weights * per_example_loss) #一個batch的loss 
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator  #平均loss 

  return (loss, per_example_loss, log_probs)

重要補充:預訓練中的隨機MASK函數

  核心思想:每個輸入序列,只有最多15%的token被mask,而其中80%的機會被替換成[MASK],10%的機會保持原詞不變,10%的機會隨機替換為字典中的任意詞。代碼如何實現呢?先獲取每個token的索引位置,然后隨機打亂索引位置,接着取前15%的token進行替換即可。在替換中,再次利用隨機函數,實現80%替換為[MASK]等,代碼層面利用random函數還是比較巧妙的。

def create_masked_lm_predictions(tokens, #list存放的sequence,例如[CLS,今, 天, 舉, 行, 的, 國, 家, 發, 展, 改, 革, 委, 新, 聞, 發, 布, 會, SEP]
                                 masked_lm_prob, #代碼中是0.15
                                 max_predictions_per_seq, #代碼中20
                                 vocab_words, 
                                 rng): #rng=random.Random()

  cand_indexes = []
  # [CLS]和[SEP]不能用於MASK
  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    cand_indexes.append(i)
    
  #隨機打亂索引順序
  rng.shuffle(cand_indexes)

  output_tokens = list(tokens)
  #masked token數量,從最大mask配置數和seq長度*mask比例中取一個最小數,作為這個seq最終的mask數量
  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
  
  masked_lms = []
  #covered_indexes存放被mask token的索引位置
  covered_indexes = set()
  for index in cand_indexes:
     #達到mask的數量,就停止
    if len(masked_lms) >= num_to_predict:
      break
    if index in covered_indexes:
      continue
    covered_indexes.add(index)

    masked_token = None
    # 80% of the time, replace with [MASK],替換為[MASK]
    if rng.random() < 0.8:
      masked_token = "[MASK]"
    else:
      # 10% of the time, keep original,保持原詞
      if rng.random() < 0.5:
        masked_token = tokens[index]
      # 10% of the time, replace with random word,隨機替換
      else:
        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
        
    #將masked_token替換覆蓋原token
    output_tokens[index] = masked_token
    
    #保存masked token的原索引位置,及真實的label token
    masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

  # 按照下標重排,保證是原來句子中出現的順序
  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)
    
  #返回帶mask的sequence tokens,被masked token的原索引位置,及原來的真實label token ,以便計算loss
  return (output_tokens, masked_lm_positions, masked_lm_labels)

 

舉例實現隨機替換的思想:

import random,collections
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
#返回的是一個Random對象,每次再調用rng.random()都返回一個0~1的隨機數,這里與bert原代碼保持一致,種子都是12345
rng=random.Random(12345)#這里rng一定要放在函數外面,這樣相當於在外部完成初始化,每次調用函數才會隨機生成不斷變化的結果
def create_mask_sample(sequence="",mask_prob=0.15,vocab_words=[],rng=None):
    tokens=[]
    cand_indexes = []
    for i,w in enumerate(sequence):
        cand_indexes.append(i)
        tokens.append(w)
        
    #隨機打亂索引順序
    rng.shuffle(cand_indexes)
    #mask后輸出tokens
    output_tokens = list(tokens)
    #一個輸入序列中需要mask的數量
    num_to_predict = int(len(tokens)*mask_prob)
  
    masked_lms = []
    #covered_indexes存放被mask token的索引位置
    covered_indexes = set()
    for index in cand_indexes:
     #達到mask的數量,就停止
        if len(masked_lms) >= num_to_predict:
            break
        if index in covered_indexes:
            continue
        covered_indexes.add(index)

        masked_token = None
        # 80% of the time, replace with [MASK],替換為[MASK]
        if rng.random() < 0.8:     #這里有80%的概率是滿足<0.8
            masked_token = "[MASK]"
        else:                    #如果是>=0.8情況呢,這里有20%的概率
          # 剩下的概率一半保持原詞,也就是10% of the time, keep original,保持原詞
          if rng.random() < 0.5:
            masked_token = tokens[index]
          # 10% of the time, replace with random word,隨機替換
          else:
            masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
        
        #將masked_token替換覆蓋原token
        output_tokens[index] = masked_token
    
        #保存masked token的原索引位置,及真實的label token
        masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

    # 按照下標重排,保證是原來句子中出現的順序
    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = []
    masked_lm_labels = []
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    
      #返回帶mask的sequence tokens,被masked token的原索引位置,及原來的真實label token ,以便計算loss
    return (output_tokens, masked_lm_positions, masked_lm_labels)

 

#舉例子測試
seq='今天下午舉行的市新冠肺炎疫情防控工作領導小組新聞發布會透露:近期,多個國家和地區出現新冠肺炎確診病例,數量持續攀升。鑒於當前境外疫情防控形勢,結合上海實際,市防控工作領導小組及相關部門綜合研判,進一步明確了涉外疫情防控和入境人員健康管理措施。'
v_words=['', '', '', '3', '', '1', '', '', '', '', '', '調', '',
 '', '', '', '', '', '', '', '', '', '', '', '', '',
 '', '', '', '', '', '', '', '', '', '', '', '', '',
 '', '', '', '', '', '', '', '', '', '', '8', '', '',
 '', '', '', '', '', '', '', '', '', '', '', '2', '4', '',
 '', '', '', '', '', '', '', '', '', '', '', '', '', '',
 '', '', '', '', '', '', '', '', '', '', '', '', '', '']
output_tokens,masked_lm_positions,masked_lm_labels=create_mask_sample(sequence=seq,mask_prob=0.1,vocab_words=v_words,rng=rng)
print(len(output_tokens))
print(''.join(output_tokens))
print(masked_lm_positions)
print(masked_lm_labels)

out:
121
今天下午[MASK]行的市新冠肺炎疫情防控工[MASK]領導小組新聞發布會透露:[MASK]期,多個國家和地區出現新冠[MASK]炎確診病例[MASK]數量持[MASK]攀升[MASK]鑒於當[MASK]境外疫情防控形勢,結合上海實際,市防控工[MASK]領導小組及相關部門[MASK]合研判,進一步明[MASK]了涉外疫情防控和入境人員[MASK]康管理措施。
[4, 17, 30, 44, 50, 54, 57, 61, 82, 92, 101, 114]
['舉', '作', '近', '肺', ',', '續', '。', '前', '作', '綜', '確', '健']

 注意:同一段話,每調用一次都會隨機生成不同的mask結果,達到隨機mask目的。

 

二、 Next Sentence Prediction

get_next_sentence_output函數用於計算「任務#2」的訓練 loss,這部分比較簡單,只需要再額外加一層softmax輸出即可。輸入為 BertModel 的最后一層 pooled_output 輸出([batch_size, hidden_size]),因為該任務屬於二分類問題,所以只需要每個序列的第一個 token【CLS】即可。

def get_next_sentence_output(bert_config,
                            input_tensor,#pooled_output 輸出,shape=[batch_size, hidden_size]
                            labels):
  """Get loss and log probs for the next sentence prediction."""

 # 標簽0表示 下一個句子關系成立;標簽1表示 下一個句子關系不成立。這個分類器的參數在實際Fine-tuning階段會丟棄掉
  with tf.variable_scope("cls/seq_relationship"):
  #初始化權重參數,最終的分類結果是只有2個,所以shape=[2,hidden_size]
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, bert_config.hidden_size],
        initializer=modeling.create_initializer(bert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())
    
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)#輸入與權重相乘,shape=[batch_size,2]
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)#softmax輸出:shape=[batch_size,2]
    
    #下面這部分就是根據真實值計算損失loss了
    labels = tf.reshape(labels, [-1])
    one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, per_example_loss, log_probs)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM