一、任務背景介紹
本次訓練實戰參照的是該篇博客文章:https://kexue.fm/archives/6933
本次訓練任務采用的是THUCNews的數據集,THUCNews是根據新浪新聞RSS訂閱頻道2005~2011年間的歷史數據篩選過濾生成,包含74萬篇新聞文檔,由多個類別的新聞標題和內容組成。本次任務的目標是利用bert結合Unilm模型的思想來訓練seq2seq模型,輸入由s1和s2兩個segment組成,s1是文章內容,s2是文章標題,在輸入的時候采用mask機制,可以參照之前的Unilm模型里的mask,如下(藍色實框表示可見):
在輸出計算loss的時候,根據segment id只計算生成標題的損失,也就是以標題部分OK為最大目標。
二、模型訓練
1)訓練邏輯示意圖
2)計算損失示意圖
在計算損失時,通過segment id=1控制,只有右側那部分sequence參與損失計算,w1-w6是什么不關心。
三、預測並解碼
1)解碼邏輯示意圖
每次的輸出都會和輸入連接一起作為新的輸入進行預測下一個word,直到遇到end符號或者滿足最大輸出max_len才結束。
2)代碼實現(beam_search)
class AutoTitle(AutoRegressiveDecoder): """seq2seq解碼器 """ def beam_search(self, inputs, topk): """beam search解碼 說明:這里的topk即beam size; 返回:最優解碼序列。 """ inputs = [np.array([i]) for i in inputs] output_ids, output_scores = self.first_output_ids, np.zeros(1) quasi_output, quasi_score = [], -np.inf for step in range(self.maxlen): scores = self.predict(inputs, output_ids, step, 'logits') # 計算當前得分,並把最新的output結果也加進去共同作為輸入。 if step == 0: # 第1步預測后將輸入重復topk次 inputs = [np.repeat(i, topk, axis=0) for i in inputs] scores = output_scores.reshape((-1, 1)) + scores # 計算累積得分,output_scores存的就是之前最大的累計概率,因為是log所以采用相加,相當於乘了 indices = scores.argpartition(-topk, axis=None)[-topk:] # 從最新的累積得分里面再找出tok最大的 indices_1 = indices // scores.shape[1] # 行索引 indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 output_ids = np.concatenate([output_ids[indices_1], indices_2], 1) # 把最新找出來的最大的token_id存放到輸出list里面中 output_scores = np.take_along_axis(scores, indices, axis=None) # 更新累積最大得分,每次存的就是累計的最大得分,也就是概率最大 best_one = output_scores.argmax() # 找出最優的序列,因為output_scores里面可能存多個序列,和tok有關,output_scores存的就是序列累計總概率分 if indices_2[best_one, 0] == self.end_id: # 判斷是否可以輸出 if output_scores[best_one] >= quasi_score: # 跟緩存比較 return output_ids[best_one] # 返回當前最優 else: return quasi_output # 返回緩存的准輸出 else: flag = (indices_2[:, 0] == self.end_id) # 標記已完成序列 if flag.any(): idx = output_scores[flag].argmax() # 准最優序列 quasi_output = output_ids[idx] # 准最優序列 quasi_score = output_scores[idx] # 准最優得分 flag = (flag == False) # 標記未完成序列 inputs = [i[flag] for i in inputs] # 只保留未完成部分輸入 output_ids = output_ids[flag] # 只保留未完成部分候選集 output_scores = output_scores[flag] # 只保留未完成部分候選得分 topk = flag.sum() # 更新topk的值 # 達到長度直接輸出return output_ids[output_scores.argmax()] @AutoRegressiveDecoder.set_rtype('probas') def predict(self, inputs, output_ids, step): token_ids, segment_ids = inputs token_ids = np.concatenate([token_ids, output_ids], 1) segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) return model.predict([token_ids, segment_ids])[:, -1]#每次輸出只留最后一個對應位輸出結果,代表是由前面的輸入生成的一個結果,一個個字生成 def generate(self, text, topk=2): max_c_len = maxlen - self.maxlen token_ids, segment_ids = tokenizer.encode(text, max_length=max_c_len) output_ids = self.beam_search([token_ids, segment_ids], topk) # 基於beam search return tokenizer.decode(output_ids) autotitle = AutoTitle(start_id=None, end_id=tokenizer._token_sep_id, maxlen=32) def just_show(): s1 = u'夏天來臨,皮膚在強烈紫外線的照射下,曬傷不可避免,因此,曬后及時修復顯得尤為重要,否則可能會造成長期傷害。專家表示,選擇曬后護膚品要慎重,蘆薈凝膠是最安全,有效的一種選擇,曬傷嚴重者,還請及 時 就醫 。' s2 = u'8月28日,網絡爆料稱,華住集團旗下連鎖酒店用戶數據疑似發生泄露。從賣家發布的內容看,數據包含華住旗下漢庭、禧玥、桔子、宜必思等10余個品牌酒店的住客信息。泄露的信息包括華住官網注冊資料、酒店入住登記的身份信息及酒店開房記錄,住客姓名、手機號、郵箱、身份證號、登錄賬號密碼等。賣家對這個約5億條數據打包出售。第三方安全平台威脅獵人對信息出售者提供的三萬條數據進行驗證,認為數據真實性非常高。當天下午 ,華 住集 團發聲明稱,已在內部迅速開展核查,並第一時間報警。當晚,上海警方消息稱,接到華住集團報案,警方已經介入調查。' s1 = u'夏天' for s in [s1]: print(u'生成標題:', autotitle.generate(s)) print() just_show()
3) numpy其它輔助函數
#求索引位置的函數
Array.argpartition a = np.array([[7,16,15,90],[6,7,91,9]]) #先對原來的數組進行了排序,輸出的是排序后值得索引位置,比如6最小,所以第一個就是6的索引位置4 a.argpartition(-2, axis=None) #找出top2的索引位置,里面兩個list認為是一個長list構建索引位置的,[-2:]就是取后面最大的兩位 a.argpartition(-2, axis=None)[-2:]
OUT:
array([4, 0, 5, 7, 2, 1, 3, 6], dtype=int64)
array([3, 6], dtype=int64)
#數組合並函數 numpy.concatenate a=np.array([[1,2,3],[4,5,6]]) b=np.array([[6]]).reshape((-1, 1)) c=np.array([0]) #將b合並到a的第c個list里面,1表示按列添加,0表示按行添加 np.concatenate([a[c], b], 1)
OUT:
array([[1, 2, 3, 6]])
#根據索引位置提取值 numpy.take_along_axis a=np.array([[7,8,9,10],[99,100,88,87]]) c=np.array([2,5]) #根據c的值作為索引位置在a中進行查找,a中的兩個list合並為一個長list構建索引位置的 np.take_along_axis(a,c,axis=None) OUT: array([ 9, 100])