Beam Search快速理解及代碼解析(上)


Beam Search

簡單介紹一下在文本生成任務中常用的解碼策略Beam Search(集束搜索)。

生成式任務相比普通的分類、tagging等NLP任務會復雜不少。在生成的時候,模型的輸出是一個時間步一個時間步依次獲得的,而且前面時間步的結果還會影響后面時間步的結果。也就是說,每一個時間步,模型給出的都是基於歷史生成結果的條件概率。為了生成完整的句子,需要一個稱為解碼的額外動作來融合模型多個時間步的輸出,而且使得最終得到的序列的每一步條件概率連乘起來最大。

在文本生成任務中,每一個時間步可能的輸出種類稱為字典大小(vocabulary size,我們用V表示),進行T步隨機的生成可能獲得的結果總共有V^T種。拿中文文本生成來說,V 的值大約是5000-6000,即常用漢字的個數。在如此大的基數下,遍歷整個生成空間是不現實的。

貪心搜索

每一個時間步都取出一個條件概率最大的輸出,如圖:

在這里插入圖片描述

Beam Search

思路也很簡單,就是稍微放寬一些考察的范圍。在每一個時間步,不再只保留當前分數最高的1個輸出,而是保留num_beams個。當num_beams=1時集束搜索就退化成了貪心搜索。

在這里插入圖片描述

Beam Search示意圖

  • 在第一個時間步,A和C是最優的兩個,因此得到了兩個結果[A],[C],其他三個就被拋棄了;

  • 第二步會基於這兩個結果繼續進行生成,在A這個分支可以得到5個候選人,[AA],[AB],[AC],[AD],[AE],C也同理得到5個,此時會對這10個進行統一排名,再保留最優的兩個,即圖中的[AB]和[CE];

  • 第三步同理,也會從新的10個候選人里再保留最好的兩個,最后得到了[ABD],[CED]兩個結果。 可以發現,beam search在每一步需要考察的候選人數量是貪心搜索的num_beams倍,因此是一種犧牲時間換性能的方法。

Beam Search代碼解析

Beam Search的原理雖然簡單,但實際實現的時候卻有很多細節要考慮。下面要解析這個實現出自於NLP界著名Python包Transformers[1],我為了說明方便做了一些改動。

一個正確且高效的算法需要處理的問題大概有兩個:

  • 充分利用硬件,可以處理批量數據,且盡量使用並行計算少用循環

  • 處理好長短不同的生成結果

下面是基礎版的beam search函數定義。其中context是編碼器編碼獲得的向量,batch_size是每批數據中包含的樣本量,bos_token_id是句子開頭標志的token id,pad_token_id是用於填充的token id,eos_token_id是句子結束標志的token id。這里給參數填上的默認值和我們后面講解時使用的例子是一致的。

def beam_search_generate(context,
                        batch_size=3,
                        max_length=20,
                        min_length=2,
                        num_beams=2,
                        bos_token_id=101,
                        pad_token_id=0,
                        eos_token_id=102,
                        ):
    pass

 

 

在函數中主要執行以下三個步驟:

  • 准備初始輸入

  • 在當前生成的序列長度未達到max_length時擴展生成序列

  • 准備最終輸出的序列

 

准備初始輸入

# 建立beam容器,每個樣本一個
generated_hyps = [
    BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
    for _ in range(batch_size)
]
 
# 每個beam容器的得分,共batch_size*num_beams個
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=encoder_input_ids.device)
beam_scores = beam_scores.view(-1)
 
# 每個樣本是否完成生成,共batch_size個
done = [False for _ in range(batch_size)]
 
# 為了並行計算,一次生成batch_size*num_beams個序列
# 第一步自動填入bos_token
input_ids = torch.full(
    (batch_size*num_beams, 1),
    bos_token_id,
    dtype=torch.long,
    device=next(self.parameters()).device,
)
 
# 當前長度設為1
cur_len = 1

 

其中BeamHypotheses是一個容器類,每個樣本綁定一個。每個容器中會維護num_beams個當前最優的序列。當往容器中添加一個序列而導致序列數大於num_beams的時候,它會自動踢掉分數最低的那個序列。類代碼如下。

class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty):
        self.max_length = max_length - 1   # ignoring bos_token
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9
 
    def __len__(self):
        return len(self.beams)
 
    def add(self, hyp, sum_logprobs):
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            # 可更新的情況:數量未飽和或超過最差得分
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                # 數量飽和需要刪掉一個最差的
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)
 
    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        相關樣本是否已經完成生成。
        best_sum_logprobs是新的候選序列中的最高得分。
        """
 
        if len(self) < self.num_beams:
            return False
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            # 是否最高分比當前保存的最低分還差
            ret = self.worst_score >= cur_score
            return ret

 

序列擴展

序列擴展是beam search的核心過程,我們特地畫了一張圖來解釋這個版本的實現策略。

img

序列擴展示意圖,下面對照這個圖來講解代碼。

while cur_len < max_length:
    # 將編碼器得到的上下文向量和當前結果輸入解碼器,即圖中1
    output = decoder.decode_next_step(context, input_ids)
    # 輸出矩陣維度為:(batch*num_beams)*cur_len*vocab_size
    
    # 取出最后一個時間步的各token概率,即當前條件概率
    # (batch*num_beams)*vocab_size
    scores = next_token_logits = output[:, -1, :]
 
    ###########################
    # 這里可以做一大堆操作減少重復 #
    ###########################
 
    # 計算序列條件概率的,因為取了log,所以直接相加即可。得到圖中2矩陣
    # (batch_size * num_beams, vocab_size)
    next_scores = scores + beam_scores[:, None].expand_as(scores)
 
    # 為了提速,將結果重排成圖中3的形狀
    next_scores = next_scores.view(
            batch_size, num_beams * vocab_size
        )  # (batch_size, num_beams * vocab_size)
 
    # 取出分數最高的token(圖中黑點)和其對應得分
    # sorted=True,保證返回序列是有序的
    next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
 
    # 下一個時間步整個batch的beam列表
    # 列表中的每一個元素都是三元組
    # (分數, token_id, beam_id)
    next_batch_beam = []
 
    # 對每一個樣本進行擴展
    for batch_idx in range(batch_size):
 
        # 檢查樣本是否已經生成結束
        if done[batch_idx]:
            # 對於已經結束的句子,待添加的是pad token
            next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
            continue
 
        # 當前樣本下一個時間步的beam列表
        next_sent_beam = []
 
        # 對於還未結束的樣本需要找到分數最高的num_beams個擴展
        # 注意,next_scores和next_tokens是對應的
        # 而且已經按照next_scores排好順序
        for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx])
        ):
            # get beam and word IDs
            # 這兩行可參考圖中3進行理解
            beam_id = beam_token_id // vocab_size
            token_id = beam_token_id % vocab_size
 
            effective_beam_id = batch_idx * num_beams + beam_id
 
            # 如果出現了EOS token說明已經生成了完整句子
            if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                if is_beam_token_worse_than_top_num_beams:
                    continue
                # 往容器中添加這個序列
                generated_hyps[batch_idx].add(
                    input_ids[effective_beam_id].clone(), beam_token_score.item(),
                )
            else:
                # add next predicted word if it is not eos_token
                next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
 
            # 擴展num_beams個就夠了
            if len(next_sent_beam) == num_beams:
                break
 
        # 檢查這個樣本是否已經生成完了,有兩種情況
        # 1. 已經記錄過該樣本結束
        # 2. 新的結果沒有使結果改善
        done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
            next_scores[batch_idx].max().item(), cur_len=cur_len
        )
 
        # 把當前樣本的結果添加到batch結果的后面
        next_batch_beam.extend(next_sent_beam)
 
    # 如果全部樣本都已經生成結束便可以直接退出了
    if all(done):
        break
    
    # 把三元組列表再還原成三個獨立列表
    beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
    beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
    beam_idx = input_ids.new([x[2] for x in next_batch_beam])
 
    # 准備下一時刻的解碼器輸入
    # 取出實際被擴展的beam
    input_ids = input_ids[beam_idx, :]
    # 在這些beam后面接上新生成的token
    input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
 
    # 更新當前長度
    cur_len = cur_len + 1
    # end of length while

 

准備輸出

上面那個while循環跳出意味着已經生成了長度為max_length的文本,比較理想的情況是所有的句子都已經生成出了eos_token_id,即句子生成結束了。但並不是所有情況都這樣,對於那些”意猶未盡“的樣本,我們需要先手動結束。

# 將未結束的生成結果結束,並置入容器中
for batch_idx in range(batch_size):
    # 已經結束的樣本不需處理
    if done[batch_idx]:
        continue# 把結果加入到generated_hyps容器
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        final_score = beam_scores[effective_beam_id].item()
        final_tokens = input_ids[effective_beam_id]
        generated_hyps[batch_idx].add(final_tokens,final_score)

 

經過上面的處理,所有生成好的句子都已經保存在generated_hyps容器中,每個容器內保存着num_beams個序列,最后就是輸出期望個數的句子。

# select the best hypotheses,最終輸出
# 每個樣本返回幾個句子
output_num_return_sequences_per_batch = 1
# 記錄每個返回句子的長度,用於后面pad
sent_lengths = input_ids.new(output_batch_size)
best = []
​
# 對每個樣本取出最好的output_num_return_sequences_per_batch個句子
for i, hypotheses in enumerate(generated_hyps):
    sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
    for j in range(output_num_return_sequences_per_batch):
        effective_batch_idx = output_num_return_sequences_per_batch * i + j
        best_hyp = sorted_hyps.pop()[1]
        sent_lengths[effective_batch_idx] = len(best_hyp)
        best.append(best_hyp)
​
# 如果長短不一則pad句子,使得最后返回結果的長度一樣
if sent_lengths.min().item() != sent_lengths.max().item():
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    # 先把輸出矩陣填滿PAD token
    decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
​
    # 填入真正的內容
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        # 填上eos token
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id
else:
    # 所有生成序列都還沒結束,直接堆疊即可
    decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
​
# 返回的結果包含BOS token
return decoded

 

總結

好了,上面就是最基礎的beam search算法。這樣生成出來的結果已經會比貪心搜索好一些,但還是會遇到諸如詞語重復這樣的問題。其實已經有很多針對重復問題的研究,還有下篇。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM