基於keras4bert的seq2seq機制的文章標題生成


一、任務背景介紹

本次訓練實戰參照的是該篇博客文章:https://kexue.fm/archives/6933

本次訓練任務采用的是THUCNews的數據集,THUCNews是根據新浪新聞RSS訂閱頻道2005~2011年間的歷史數據篩選過濾生成,包含74萬篇新聞文檔,由多個類別的新聞標題和內容組成。本次任務的目標是利用bert結合Unilm模型的思想來訓練seq2seq模型,輸入由s1和s2兩個segment組成,s1是文章內容,s2是文章標題,在輸入的時候采用mask機制,可以參照之前的Unilm模型里的mask,如下(藍色實框表示可見):

在輸出計算loss的時候,根據segment  id只計算生成標題的損失,也就是以標題部分OK為最大目標。

 

二、模型訓練

1)訓練邏輯示意圖

 

 

2)計算損失示意圖

在計算損失時,通過segment id=1控制,只有右側那部分sequence參與損失計算,w1-w6是什么不關心。

 

三、預測並解碼

1)解碼邏輯示意圖

 

每次的輸出都會和輸入連接一起作為新的輸入進行預測下一個word,直到遇到end符號或者滿足最大輸出max_len才結束。

 

2)代碼實現(beam_search)

class AutoTitle(AutoRegressiveDecoder):
    """seq2seq解碼器
    """
    def beam_search(self, inputs, topk):
        """beam search解碼
        說明:這里的topk即beam size;
        返回:最優解碼序列。
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        quasi_output, quasi_score = [], -np.inf
        for step in range(self.maxlen):
            scores = self.predict(inputs, output_ids, step, 'logits')  # 計算當前得分,並把最新的output結果也加進去共同作為輸入。
            if step == 0:  # 第1步預測后將輸入重復topk次
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
        
            scores = output_scores.reshape((-1, 1)) + scores  # 計算累積得分,output_scores存的就是之前最大的累計概率,因為是log所以采用相加,相當於乘了
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # 從最新的累積得分里面再找出tok最大的
            indices_1 = indices // scores.shape[1]  # 行索引
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # 列索引
            output_ids = np.concatenate([output_ids[indices_1], indices_2], 1)  # 把最新找出來的最大的token_id存放到輸出list里面中
            output_scores = np.take_along_axis(scores, indices, axis=None)  # 更新累積最大得分,每次存的就是累計的最大得分,也就是概率最大
            
            best_one = output_scores.argmax()  # 找出最優的序列,因為output_scores里面可能存多個序列,和tok有關,output_scores存的就是序列累計總概率分            
            if indices_2[best_one, 0] == self.end_id:  # 判斷是否可以輸出
                if output_scores[best_one] >= quasi_score:  # 跟緩存比較
                    return output_ids[best_one]  # 返回當前最優
                else:
                    return quasi_output  # 返回緩存的准輸出
            else:
                flag = (indices_2[:, 0] == self.end_id)  # 標記已完成序列
                if flag.any():
                    idx = output_scores[flag].argmax()  # 准最優序列
                    quasi_output = output_ids[idx]  # 准最優序列
                    quasi_score = output_scores[idx]  # 准最優得分
                    flag = (flag == False)  # 標記未完成序列
                    inputs = [i[flag] for i in inputs]  # 只保留未完成部分輸入
                    output_ids = output_ids[flag]  # 只保留未完成部分候選集
                 
                    output_scores = output_scores[flag]  # 只保留未完成部分候選得分
                    topk = flag.sum()  # 更新topk的值
        # 達到長度直接輸出return output_ids[output_scores.argmax()]
    
    @AutoRegressiveDecoder.set_rtype('probas')
    def predict(self, inputs, output_ids, step):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return model.predict([token_ids, segment_ids])[:, -1]#每次輸出只留最后一個對應位輸出結果,代表是由前面的輸入生成的一個結果,一個個字生成

    def generate(self, text, topk=2):
        max_c_len = maxlen - self.maxlen
        token_ids, segment_ids = tokenizer.encode(text, max_length=max_c_len)
        output_ids = self.beam_search([token_ids, segment_ids], topk)  # 基於beam search
        
        return tokenizer.decode(output_ids)


autotitle = AutoTitle(start_id=None,
                      end_id=tokenizer._token_sep_id,
                      maxlen=32)

def just_show():
    s1 = u'夏天來臨,皮膚在強烈紫外線的照射下,曬傷不可避免,因此,曬后及時修復顯得尤為重要,否則可能會造成長期傷害。專家表示,選擇曬后護膚品要慎重,蘆薈凝膠是最安全,有效的一種選擇,曬傷嚴重者,還請及 時 就醫 。'
    s2 = u'8月28日,網絡爆料稱,華住集團旗下連鎖酒店用戶數據疑似發生泄露。從賣家發布的內容看,數據包含華住旗下漢庭、禧玥、桔子、宜必思等10余個品牌酒店的住客信息。泄露的信息包括華住官網注冊資料、酒店入住登記的身份信息及酒店開房記錄,住客姓名、手機號、郵箱、身份證號、登錄賬號密碼等。賣家對這個約5億條數據打包出售。第三方安全平台威脅獵人對信息出售者提供的三萬條數據進行驗證,認為數據真實性非常高。當天下午 ,華 住集 團發聲明稱,已在內部迅速開展核查,並第一時間報警。當晚,上海警方消息稱,接到華住集團報案,警方已經介入調查。'
    s1 = u'夏天'
    for s in [s1]:
        print(u'生成標題:', autotitle.generate(s))
    print()
just_show()

 

3) numpy其它輔助函數

#求索引位置的函數
Array.argpartition
a = np.array([[7,16,15,90],[6,7,91,9]]) #先對原來的數組進行了排序,輸出的是排序后值得索引位置,比如6最小,所以第一個就是6的索引位置4 a.argpartition(-2, axis=None) #找出top2的索引位置,里面兩個list認為是一個長list構建索引位置的,[-2:]就是取后面最大的兩位 a.argpartition(-2, axis=None)[-2:]

OUT:

   array([4, 0, 5, 7, 2, 1, 3, 6], dtype=int64)
   array([3, 6], dtype=int64)

#數組合並函數
numpy.concatenate

a=np.array([[1,2,3],[4,5,6]])
b=np.array([[6]]).reshape((-1, 1))
c=np.array([0])
#將b合並到a的第c個list里面,1表示按列添加,0表示按行添加
np.concatenate([a[c], b], 1)

OUT:
array([[1, 2, 3, 6]])
#根據索引位置提取值
numpy.take_along_axis

a=np.array([[7,8,9,10],[99,100,88,87]])
c=np.array([2,5])
#根據c的值作為索引位置在a中進行查找,a中的兩個list合並為一個長list構建索引位置的
np.take_along_axis(a,c,axis=None)


OUT:
array([  9, 100])

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM