一、Masked LM
get_masked_lm_output
函數用於計算「任務#1」的訓練 loss。輸入為 BertModel 的最后一層 sequence_output 輸出([batch_size, seq_length, hidden_size]),先找出輸出結果中masked掉的詞,然后構建一層全連接網絡,接着構建一層節點數為vocab_size的softmax輸出,從而與真實label計算損失。
def get_masked_lm_output(bert_config, input_tensor, #BertModel的最后一層sequence_output輸出model.get_sequence_output()[batch_size, seq_length, hidden_size] output_weights,#輸入是model.get_embedding_table(),[vocab_size,hidden_size] positions, #mask詞的位置 label_ids, #label,真實值結果 label_weights): """Get loss and log probs for the masked LM.""" # 根據positions位置獲取masked詞在Transformer的輸出結果,即要預測的那些位置的encoder input_tensor = gather_indexes(input_tensor, positions)#[batch_size*max_pred_pre_seq,hidden_size] with tf.variable_scope("cls/predictions"): # 在輸出之前添加一個帶激活函數的全連接神經網絡,只在預訓練階段起作用 with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=bert_config.hidden_size, activation=modeling.get_activation(bert_config.hidden_act), kernel_initializer=modeling.create_initializer( bert_config.initializer_range)) input_tensor = modeling.layer_norm(input_tensor) # output_weights是和傳入的word embedding一樣的,這里再添加一個bias output_bias = tf.get_variable( "output_bias", shape=[bert_config.vocab_size], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True) #[batch_size*max_pred_pre_seq,vocab_size] logits = tf.nn.bias_add(logits, output_bias) log_probs = tf.nn.log_softmax(logits, axis=-1)#得出masked詞的softmax結果,[batch_size*max_pred_pre_seq,vocab_size] # label_ids表示mask掉的Token的id,下面這部分就是根據真實值計算loss了。 label_ids = tf.reshape(label_ids, [-1])#[batch_size*max_pred_per_seq] label_weights = tf.reshape(label_weights, [-1]) one_hot_labels = tf.one_hot( label_ids, depth=bert_config.vocab_size, dtype=tf.float32)#[batch_size*max_pred_per_seq,vocab_size] # 但是由於實際MASK的可能不到20,比如只MASK18,那么label_ids有2個0(padding),而label_weights=[1, 1, ...., 0, 0],說明后面兩個label_id是padding的,計算loss要去掉,label_weights就是起一個標記作用 per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])#[batch_size*max_pred_per_seq] numerator = tf.reduce_sum(label_weights * per_example_loss) #一個batch的loss denominator = tf.reduce_sum(label_weights) + 1e-5 loss = numerator / denominator #平均loss return (loss, per_example_loss, log_probs)
重要補充:預訓練中的隨機MASK函數
核心思想:每個輸入序列,只有最多15%的token被mask,而其中80%的機會被替換成[MASK],10%的機會保持原詞不變,10%的機會隨機替換為字典中的任意詞。代碼如何實現呢?先獲取每個token的索引位置,然后隨機打亂索引位置,接着取前15%的token進行替換即可。在替換中,再次利用隨機函數,實現80%替換為[MASK]等,代碼層面利用random函數還是比較巧妙的。
def create_masked_lm_predictions(tokens, #list存放的sequence,例如[CLS,今, 天, 舉, 行, 的, 國, 家, 發, 展, 改, 革, 委, 新, 聞, 發, 布, 會, SEP] masked_lm_prob, #代碼中是0.15 max_predictions_per_seq, #代碼中20 vocab_words, rng): #rng=random.Random() cand_indexes = [] # [CLS]和[SEP]不能用於MASK for (i, token) in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue cand_indexes.append(i) #隨機打亂索引順序 rng.shuffle(cand_indexes) output_tokens = list(tokens) #masked token數量,從最大mask配置數和seq長度*mask比例中取一個最小數,作為這個seq最終的mask數量 num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) masked_lms = [] #covered_indexes存放被mask token的索引位置 covered_indexes = set() for index in cand_indexes: #達到mask的數量,就停止 if len(masked_lms) >= num_to_predict: break if index in covered_indexes: continue covered_indexes.add(index) masked_token = None # 80% of the time, replace with [MASK],替換為[MASK] if rng.random() < 0.8: masked_token = "[MASK]" else: # 10% of the time, keep original,保持原詞 if rng.random() < 0.5: masked_token = tokens[index] # 10% of the time, replace with random word,隨機替換 else: masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] #將masked_token替換覆蓋原token output_tokens[index] = masked_token #保存masked token的原索引位置,及真實的label token masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) # 按照下標重排,保證是原來句子中出現的順序 masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_positions = [] masked_lm_labels = [] for p in masked_lms: masked_lm_positions.append(p.index) masked_lm_labels.append(p.label) #返回帶mask的sequence tokens,被masked token的原索引位置,及原來的真實label token ,以便計算loss return (output_tokens, masked_lm_positions, masked_lm_labels)
舉例實現隨機替換的思想:
import random,collections MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) #返回的是一個Random對象,每次再調用rng.random()都返回一個0~1的隨機數,這里與bert原代碼保持一致,種子都是12345 rng=random.Random(12345)#這里rng一定要放在函數外面,這樣相當於在外部完成初始化,每次調用函數才會隨機生成不斷變化的結果 def create_mask_sample(sequence="",mask_prob=0.15,vocab_words=[],rng=None): tokens=[] cand_indexes = [] for i,w in enumerate(sequence): cand_indexes.append(i) tokens.append(w) #隨機打亂索引順序 rng.shuffle(cand_indexes) #mask后輸出tokens output_tokens = list(tokens) #一個輸入序列中需要mask的數量 num_to_predict = int(len(tokens)*mask_prob) masked_lms = [] #covered_indexes存放被mask token的索引位置 covered_indexes = set() for index in cand_indexes: #達到mask的數量,就停止 if len(masked_lms) >= num_to_predict: break if index in covered_indexes: continue covered_indexes.add(index) masked_token = None # 80% of the time, replace with [MASK],替換為[MASK] if rng.random() < 0.8: #這里有80%的概率是滿足<0.8 masked_token = "[MASK]" else: #如果是>=0.8情況呢,這里有20%的概率 # 剩下的概率一半保持原詞,也就是10% of the time, keep original,保持原詞 if rng.random() < 0.5: masked_token = tokens[index] # 10% of the time, replace with random word,隨機替換 else: masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] #將masked_token替換覆蓋原token output_tokens[index] = masked_token #保存masked token的原索引位置,及真實的label token masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) # 按照下標重排,保證是原來句子中出現的順序 masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_positions = [] masked_lm_labels = [] for p in masked_lms: masked_lm_positions.append(p.index) masked_lm_labels.append(p.label) #返回帶mask的sequence tokens,被masked token的原索引位置,及原來的真實label token ,以便計算loss return (output_tokens, masked_lm_positions, masked_lm_labels)
#舉例子測試
seq='今天下午舉行的市新冠肺炎疫情防控工作領導小組新聞發布會透露:近期,多個國家和地區出現新冠肺炎確診病例,數量持續攀升。鑒於當前境外疫情防控形勢,結合上海實際,市防控工作領導小組及相關部門綜合研判,進一步明確了涉外疫情防控和入境人員健康管理措施。' v_words=['自', '今', '年', '3', '月', '1', '日', '起', ',', '重', '新', '調', '整', '存', '量', '房', '貸', '利', '率', ',', '存', '量', '浮', '動', '利', '率', '貸', '款', '客', '戶', '可', '以', '有', '兩', '個', '選', '擇', ',', '原', '則', '上', '轉', '換', '工', '作', '應', '於', '今', '年', '8', '月', '底', '前', '完', '成', '。', '目', '前', ',', '已', '有', '至', '少', '2', '4', '家', '主', '要', '銀', '行', '發', '布', '了', '相', '關', '公', '告', ',', '多', '家', '銀', '行', '稱', '還', '將', '陸', '續', '發', '送', '一', '對', '一', '短', '信'] output_tokens,masked_lm_positions,masked_lm_labels=create_mask_sample(sequence=seq,mask_prob=0.1,vocab_words=v_words,rng=rng) print(len(output_tokens)) print(''.join(output_tokens)) print(masked_lm_positions) print(masked_lm_labels)
out:
121 今天下午[MASK]行的市新冠肺炎疫情防控工[MASK]領導小組新聞發布會透露:[MASK]期,多個國家和地區出現新冠[MASK]炎確診病例[MASK]數量持[MASK]攀升[MASK]鑒於當[MASK]境外疫情防控形勢,結合上海實際,市防控工[MASK]領導小組及相關部門[MASK]合研判,進一步明[MASK]了涉外疫情防控和入境人員[MASK]康管理措施。 [4, 17, 30, 44, 50, 54, 57, 61, 82, 92, 101, 114] ['舉', '作', '近', '肺', ',', '續', '。', '前', '作', '綜', '確', '健']
注意:同一段話,每調用一次都會隨機生成不同的mask結果,達到隨機mask目的。
二、 Next Sentence Prediction
get_next_sentence_output函數用於計算「任務#2」的訓練 loss,這部分比較簡單,只需要再額外加一層softmax輸出即可。輸入為 BertModel 的最后一層 pooled_output 輸出([batch_size, hidden_size]),因為該任務屬於二分類問題,所以只需要每個序列的第一個 token【CLS】即可。
def get_next_sentence_output(bert_config, input_tensor,#pooled_output 輸出,shape=[batch_size, hidden_size] labels): """Get loss and log probs for the next sentence prediction.""" # 標簽0表示 下一個句子關系成立;標簽1表示 下一個句子關系不成立。這個分類器的參數在實際Fine-tuning階段會丟棄掉 with tf.variable_scope("cls/seq_relationship"): #初始化權重參數,最終的分類結果是只有2個,所以shape=[2,hidden_size] output_weights = tf.get_variable( "output_weights", shape=[2, bert_config.hidden_size], initializer=modeling.create_initializer(bert_config.initializer_range)) output_bias = tf.get_variable( "output_bias", shape=[2], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True)#輸入與權重相乘,shape=[batch_size,2] logits = tf.nn.bias_add(logits, output_bias) log_probs = tf.nn.log_softmax(logits, axis=-1)#softmax輸出:shape=[batch_size,2] #下面這部分就是根據真實值計算損失loss了 labels = tf.reshape(labels, [-1]) one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) loss = tf.reduce_mean(per_example_loss) return (loss, per_example_loss, log_probs)