CTC 解碼算法之 prefix beam search


ctc prefix beam search 算法

CTC 網絡的輸出 net_out 形狀為 T × C ,其中 T 是時間長度, C 是字符類別數加1(額外的blank)。
CTC 的 beam search 算法維護的不是 K 個路徑前綴,而是 K 個標簽前綴,但仍需要考慮其背后的路徑(路徑到標簽的多對一關系)。每個時間步,對 K 個前綴進行擴展,用字符表中的字符對已有前綴做擴展,得到新的多個前綴,然后計算這些前綴的概率,從中挑選出概率最大的 K 個保存,不斷重復這個過程直到最后一個時間步,然后選出概率最大的一個結果作為最終的標簽。

第一種實現

# -*-coding:utf-8-*-
from collections import defaultdict, Counter
from string import ascii_lowercase
import re
import numpy as np


def prefix_beam_search(ctc, lm=None, k=25, alpha=0.30, beta=5, prune=0.001):
    """
    對CTC網絡的輸出做beam search
    Args:
            ctc (np.ndarray): The CTC網絡輸出. 2D array形狀為(timesteps x alphabet_size)
            lm (func): 語言模型函數. 接收一個字符串做參數,輸出一個概率
            k (int): beam寬度. 每個時間步將保存 k 個概率最大的候選前綴
            alpha (float): 語言模型的權重,取值在0到1
            beta (float): 語言模型的懲罰(獎勵)項權重. alpha越大,beta應該越大
            prune (float): ctc的每個時間步的輸出分布中概率大於prune的才參與前綴擴展

    Retruns:
            string: 返回解碼結果
    """

    # 沒有提供語言模型則始終返回1
    lm = (lambda l: 1) if lm is None else lm
    # 正則匹配l中所有的單詞
    W = (lambda l: re.findall(r'\w+[\s|>]', l))
    # 字母表是小寫英文字母及空格、結束符、空白標簽
    alphabet = list(ascii_lowercase) + [' ', '>', '%']
    F = ctc.shape[1]
    # 對ctc輸出添加一個想象中的時間步0,用於初始化空,手動傳個blank進來
    ctc = np.vstack((np.zeros(F), ctc))
    T = ctc.shape[0]

    # 空前綴
    origin = ''
    # 每個時間步下beam里雖然存的是前綴集合,但需要考慮其背后對應的路徑集合
    # Pb表示前綴由那些以blank結尾的路徑生成的概率,Pnb表示前綴由那些以non-blank結尾的路徑生成的概率
    Pb, Pnb = defaultdict(Counter), defaultdict(Counter)
    # 因為手動傳blank進來,所以以blank結尾且前綴為空的概率為1
    Pb[0][origin] = 1
    # non-blank結尾的路徑生成空前綴,不可能發生
    Pnb[0][origin] = 0
    # A_prev保存當前時間步開始擴展之前所保留的概率最大的候選前綴集合,數目小於等於 k
    A_prev = [origin]
    # 不斷的擴展前綴,始終保留概率最大的 k 個
    # 路徑空間上的事件定義:
    # A(t, l)代表到t步為止生成l,
    # Ab(t, l)代表到t步為止生成l且末位為blank,
    # Anb(t, l)代表到t步為止生成l且末位為non-blank
    for t in range(1, T):
        # 對當前時間步的字母表分布,選取概率大於prune的,減少運算
        pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]]
        # 因為同一個時間步A_prev里不同的前綴可能會擴展出相同的新前綴,所以概率用增量式而不是賦值
        for l in A_prev:
            # 當前前綴已經到句末了,不能再擴展
            if len(l) > 0 and l[-1] == '>':
                Pb[t][l] = Pb[t - 1][l]
                Pnb[t][l] = Pnb[t - 1][l]
                continue
            # 每個l都代表着A(t-1, l)事件
            for c in pruned_alphabet:
                c_ix = alphabet.index(c)
                # A(t-1, l)遇到blank,只有一種結果,即Ab(t, l)
                # 計算概率貢獻P(t-1, l) * P(blank, t)
                if c == '%':
                    Pb[t][l] += ctc[t][-1] * (Pb[t - 1][l] + Pnb[t - 1][l])
                else:
                    l_plus = l + c
                    # l中有多種路徑來源,遇到c后產生多種結果,需要分別計算不同來源
                    # 經過c擴展后得到不同結果的概率,對特定事件的概率做出貢獻
                    if len(l) > 0 and c == l[-1]:
                        # A(t-1, l)中來自blank結尾的路徑經過c擴展得到l_plus,Anb(t, l_plus)發生,計算概率貢獻
                        Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l]
                        # A(t-1, l)中來自non-blank結尾的路徑經過c擴展,維持l不變,Anb(t, l)發生,計算概率貢獻
                        Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l]
                    # c既不是l末元素也不是blank,A(t-1, l)+c只有一種結果即Anb(t, l_plus),計算概率貢獻,但是
                    # 計算方式需要根據l當前狀態做調整
                    elif len(l.replace(' ', '')) > 0 and c in (' ', '>'):
                        lm_prob = lm(l_plus.strip(' >')) ** alpha
                        Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
                    else:
                        Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
                    # l_plus作為有可能加入A_next的新前綴卻沒有在上次出現在A_prev中,一種可能是由於
                    # beam width的限制,A(t-1, l_plus)概率沒排在前 k,導致沒被A_prev收錄,但是
                    # 在本次擴展結果中,再次需要考慮l_plus是否能加入A_next了,即計算A(t, l_plus)的
                    # 概率排名,但是由於l_plus不在A_prev中,只能考慮新擴展(l + c)的方式獲得的概率貢獻。
                    # 假設沒有beam width的限制,l_plus之前是在A_prev中的,那就還有兩種方式來在本次獲得概率貢獻,
                    # 一種是本次擴展一個blank,一種是本次擴展一個重復字符,這算是beam width限制下的一種查漏。
                    # 對大多數情形前一個時間步就生成了更長的l_plus是不會發生的,所以這一項大多時候不起作用。
                    if l_plus not in A_prev:
                        # A(t-1, l_plus) + blank得Ab(t, l_plus)
                        Pb[t][l_plus] += ctc[t][-1] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus])
                        # Anb(t-1, l_plus) + c得Anb(t, l_plus)
                        Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus]
        A_next = Pb[t] + Pnb[t]
        sorter = (lambda l: A_next[l] * (len(W(l)) + 1) ** beta)
        A_prev = sorted(A_next, key=sorter, reverse=True)[:k]
    return A_prev[0].strip('>')

最后用的實現

# 負無窮大
NEG_INF = float('-inf')

def logsumexp(*args):
    '''
    log概率求和,即計算log(a + b)
    使用的公式是:
    log(a + b) = log(a) + log(1 + exp(log(b) - log(a)))
    '''
    # args中都是log scale的概率,負無窮代表真實概率為0
    if all(a == NEG_INF for a in args):
        return NEG_INF
    # 用序列最大值當公式中 a
    a_max = max(args)
    lsp = np.log(sum(np.exp(a - a_max) for a in args))
    return a_max + lsp

def prune_vocab(prob, vocab_list, accumulation, max_num=50):
    """ 詞匯表裁剪
    留下累積概率超過accumulation且最多不超過max_num的類別
    """
    assert prob.shape[0] == len(vocab_list)
    assert accumulation < 1.0
    # 累積概率裁剪
    prob = np.exp(prob)
    indices = np.argsort(prob)[::-1]
    prob_sorted = np.array([prob[ii] for ii in indices])
    prob_accumulated = np.add.accumulate(prob_sorted)
    index = np.where(prob_accumulated >= accumulation)[0][0]
    # 裁剪過后最多不超過max_num個候選
    index = min(index, max_num - 1)
    part_indices = indices[:index + 1]
    # 必須把blank的索引加進來,否則beam_ext有可能全是新的beam,已有的高概率beam會被丟棄
    # (此時只有重復元素折疊的情形能保留已有beam,+blank保留beam的情形被丟棄了)
    blank_idx = len(vocab_list) - 1
    if blank_idx not in part_indices:
        part_indices = np.append(part_indices, blank_idx)
    part_vocab_list = [vocab_list[ii] for ii in part_indices]
    return part_indices, part_vocab_list

def beam_search_decode(ctc, vocab_list, str2idx_table,
                       beam_width, alpha, beta,
                       accumulation, lm_func=None):
    '''
    ctc (np.ndarray) : CTC網絡的輸出,形狀為(time_steps x alphabet_size)
    vocab_list (list) : 類別字符列表,即alphabet
    beam_width (int) : beam search保存最大概率候選的數目
    alpha (float) : 語言模型的條件概率權重,取值0到1
    beta (float) : 語言模型的序列長度獎勵權重,避免長序列概率過小,alpha越大beta也應該越大
    accumulation (float) : 類別列表中累積概率超過這個數的類別才參與擴展
    lm_func (function) : 計算某個前綴的條件概率
    '''
    # 事件定義:
    # 用A(t, l)代表 Path[1:t] -> l,有下面兩種可能
    # 用A_b(t, l)代表 Path[1:t] -> l 且 l 末位為blank
    # 用A_nb(t, l)代表 Path[1:t] -> l 且 l 末位為non-blank

    # 不使用語言模型時始終返回真實概率1.0,對數概率0.0
    lm_func = (lambda x: 0.0) if lm_func is None else lm_func
    # pb[t][l]事件A_b(t, l)的概率,log概率初始化為負無窮
    pb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
    # pnb[t][l]事件A_nb(t, l)的概率
    pnb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
    # 事件A(t, l)的概率為二者之和pb[t][l] + pnb[t][l]

    # 路徑擴展除了vocab_list之外還包含blank,用%表示
    vocab_list = vocab_list + ['%']
    # 為ctc擴展一個想象出來的時間0
    num_classes = ctc.shape[1]
    ctc = np.vstack((np.zeros((num_classes,), dtype=ctc.dtype), ctc))
    ctc = np.log(ctc)
    # 比實際時間多一個
    T = ctc.shape[0]

    empty_prefix = '^'
    # 真實概率1.0
    pb[0][empty_prefix] = 0.0
    # 真實概率0.0
    pnb[0][empty_prefix] = NEG_INF
    # beams最多存放beam_width個概率最大的(真)前綴,初始化為空前綴
    beams = [empty_prefix]

    for t in range(1, T):
        # 同一個時間步下beams里的不同前綴可能會擴展出相同的新前綴,所以下面在計算
        # 某個前綴的概率時不用賦值而用增量,表示兩種來源都會對這一事件有概率貢獻
        # 例子: beams里同時有BO和BOX,BO通過添加X擴展為BOX,而BOX擴展X維持不變
        # A(t-1,BO)+X和A(t-1,BOX)+X此二者都會促成事件A_nb(t,BOX)但是來源不同是互斥的。
        beams_ext = []
        ppstart = time.time()
        part_indices, part_vocab_list = prune_vocab(ctc[t], vocab_list, accumulation)
        if t == 1:
            print('prune vocab cost {}'.format(time.time() - ppstart))
        for l in beams:
            # 還沒擴展之前,時間還是t-1,每一個l都表示事件A(t-1, l)
            # 遍歷vocab_list擴展出新事件,討論新事件的概率
            # for c_idx, c in enumerate(vocab_list):
            for c_idx, c in zip(part_indices, part_vocab_list):
                if c == '%':
                    # A(t-1, l) + blank,前綴不變、時間+1、末位為blank,即得事件A_b(t, l)
                    # 計算這個事件的概率
                    # p_b[t][l] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, -1]
                    ll_start = time.time()
                    pb[t][l] = logsumexp(pb[t][l], pb[t - 1][l] + ctc[t, -1], pnb[t - 1][l] + ctc[t, -1])
                    # A(t, l)有發生的可能了,將l添加到beams_ext
                    if l not in beams_ext:
                        beams_ext.append(l)
                    if t == 1:
                        print('log_sum cost {}'.format(time.time() - ll_start))
                else:
                    # 以下c都為non-blank了
                    # 僅當覆蓋增加時前綴變化需要應用語言模型
                    # c_idx = str2idx_table(c)
                    l_plus = l + c
                    lm_prob = alpha * lm_func(l_plus)
                    if len(l) > 0 and c == l[-1]:
                        # A(t-1, l)中的兩種子事件A_b(t-1, l)和A_nb(t-1, l)經過重復末元素擴展
                        # 會得到不同的結果,得到兩個事件,分別計算之
                        # A_b(t-1, l) + l[-1]覆蓋增加得事件A_nb(t, l_plus)
                        # p_nb[t][l_plus] += p_b[t - 1][l] * ctc[t, c_idx] * lm_func(l_plus) ** alpha
                        pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
                                                   pb[t - 1][l] + ctc[t, c_idx] + lm_prob)
                        # A_nb(t-1, l) + l[-1]發生重復元素折疊得事件A_nb(t, l)
                        # p_nb[t][l] += p_nb[t - 1][l] * ctc[t, c_idx]
                        pnb[t][l] = logsumexp(pnb[t][l],
                                              pnb[t - 1][l] + ctc[t, c_idx])
                    else:
                        # c既非重復末元素也非blank,A(t-1, l) + c覆蓋增加且末位非blank,
                        # 只有一個事件會發生——A_nb(t, l_plus),計算概率
                        # p_nb[t][l_plus] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, c_idx] * lm_func(l_plus) ** alpha
                        pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
                                                   pb[t - 1][l] + ctc[t, c_idx] + lm_prob,
                                                   pnb[t - 1][l] + ctc[t, c_idx] + lm_prob)
                    # 不管那種情形A(t, l_plus)有可能發生了,將l_plus添加到beams_ext
                    if l_plus not in beams_ext:
                        beams_ext.append(l_plus)
        print('total {} candidates in beams_ext'.format(len(beams_ext)))
        # 將beams_ext中的前綴按概率排序
        # prefix_probs = [(p_b[t][ll] + p_nb[t][ll]) * (len(ll) + 1) ** beta for ll in beams_ext]
        # 概率
        prefix_probs = [logsumexp(pb[t][ll], pnb[t][ll]) for ll in beams_ext]
        # 排序分數,加入了序列長度獎勵
        sort_score = [logsumexp(pb[t][ll] + beta * (len(ll) + 1),
                                pnb[t][ll] + beta * (len(ll) + 1))
                      for ll in beams_ext]
        indices = sorted(range(len(sort_score)), key=lambda k: sort_score[k])[::-1][:beam_width]
        beams = [beams_ext[ii] for ii in indices]
        beams_probs = [prefix_probs[ii] for ii in indices]
        beams_scores = [sort_score[ii] for ii in indices]

        for ii in range(len(beams)):
            print('p(A({}, \'{}\')) = {} score = {}'.
                  format(t, beams[ii], beams_probs[ii], beams_scores[ii]))
        print('-' * 64)
    decoded = beams[0]
    return decoded





免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM