ctc prefix beam search 算法
CTC 網絡的輸出 net_out 形狀為
,其中
是時間長度,
是字符類別數加1(額外的blank)。
CTC 的 beam search 算法維護的不是 K 個路徑前綴,而是 K 個標簽前綴,但仍需要考慮其背后的路徑(路徑到標簽的多對一關系)。每個時間步,對 K 個前綴進行擴展,用字符表中的字符對已有前綴做擴展,得到新的多個前綴,然后計算這些前綴的概率,從中挑選出概率最大的 K 個保存,不斷重復這個過程直到最后一個時間步,然后選出概率最大的一個結果作為最終的標簽。
第一種實現
# -*-coding:utf-8-*-
from collections import defaultdict, Counter
from string import ascii_lowercase
import re
import numpy as np
def prefix_beam_search(ctc, lm=None, k=25, alpha=0.30, beta=5, prune=0.001):
"""
對CTC網絡的輸出做beam search
Args:
ctc (np.ndarray): The CTC網絡輸出. 2D array形狀為(timesteps x alphabet_size)
lm (func): 語言模型函數. 接收一個字符串做參數,輸出一個概率
k (int): beam寬度. 每個時間步將保存 k 個概率最大的候選前綴
alpha (float): 語言模型的權重,取值在0到1
beta (float): 語言模型的懲罰(獎勵)項權重. alpha越大,beta應該越大
prune (float): ctc的每個時間步的輸出分布中概率大於prune的才參與前綴擴展
Retruns:
string: 返回解碼結果
"""
# 沒有提供語言模型則始終返回1
lm = (lambda l: 1) if lm is None else lm
# 正則匹配l中所有的單詞
W = (lambda l: re.findall(r'\w+[\s|>]', l))
# 字母表是小寫英文字母及空格、結束符、空白標簽
alphabet = list(ascii_lowercase) + [' ', '>', '%']
F = ctc.shape[1]
# 對ctc輸出添加一個想象中的時間步0,用於初始化空,手動傳個blank進來
ctc = np.vstack((np.zeros(F), ctc))
T = ctc.shape[0]
# 空前綴
origin = ''
# 每個時間步下beam里雖然存的是前綴集合,但需要考慮其背后對應的路徑集合
# Pb表示前綴由那些以blank結尾的路徑生成的概率,Pnb表示前綴由那些以non-blank結尾的路徑生成的概率
Pb, Pnb = defaultdict(Counter), defaultdict(Counter)
# 因為手動傳blank進來,所以以blank結尾且前綴為空的概率為1
Pb[0][origin] = 1
# non-blank結尾的路徑生成空前綴,不可能發生
Pnb[0][origin] = 0
# A_prev保存當前時間步開始擴展之前所保留的概率最大的候選前綴集合,數目小於等於 k
A_prev = [origin]
# 不斷的擴展前綴,始終保留概率最大的 k 個
# 路徑空間上的事件定義:
# A(t, l)代表到t步為止生成l,
# Ab(t, l)代表到t步為止生成l且末位為blank,
# Anb(t, l)代表到t步為止生成l且末位為non-blank
for t in range(1, T):
# 對當前時間步的字母表分布,選取概率大於prune的,減少運算
pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]]
# 因為同一個時間步A_prev里不同的前綴可能會擴展出相同的新前綴,所以概率用增量式而不是賦值
for l in A_prev:
# 當前前綴已經到句末了,不能再擴展
if len(l) > 0 and l[-1] == '>':
Pb[t][l] = Pb[t - 1][l]
Pnb[t][l] = Pnb[t - 1][l]
continue
# 每個l都代表着A(t-1, l)事件
for c in pruned_alphabet:
c_ix = alphabet.index(c)
# A(t-1, l)遇到blank,只有一種結果,即Ab(t, l)
# 計算概率貢獻P(t-1, l) * P(blank, t)
if c == '%':
Pb[t][l] += ctc[t][-1] * (Pb[t - 1][l] + Pnb[t - 1][l])
else:
l_plus = l + c
# l中有多種路徑來源,遇到c后產生多種結果,需要分別計算不同來源
# 經過c擴展后得到不同結果的概率,對特定事件的概率做出貢獻
if len(l) > 0 and c == l[-1]:
# A(t-1, l)中來自blank結尾的路徑經過c擴展得到l_plus,Anb(t, l_plus)發生,計算概率貢獻
Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l]
# A(t-1, l)中來自non-blank結尾的路徑經過c擴展,維持l不變,Anb(t, l)發生,計算概率貢獻
Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l]
# c既不是l末元素也不是blank,A(t-1, l)+c只有一種結果即Anb(t, l_plus),計算概率貢獻,但是
# 計算方式需要根據l當前狀態做調整
elif len(l.replace(' ', '')) > 0 and c in (' ', '>'):
lm_prob = lm(l_plus.strip(' >')) ** alpha
Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
else:
Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
# l_plus作為有可能加入A_next的新前綴卻沒有在上次出現在A_prev中,一種可能是由於
# beam width的限制,A(t-1, l_plus)概率沒排在前 k,導致沒被A_prev收錄,但是
# 在本次擴展結果中,再次需要考慮l_plus是否能加入A_next了,即計算A(t, l_plus)的
# 概率排名,但是由於l_plus不在A_prev中,只能考慮新擴展(l + c)的方式獲得的概率貢獻。
# 假設沒有beam width的限制,l_plus之前是在A_prev中的,那就還有兩種方式來在本次獲得概率貢獻,
# 一種是本次擴展一個blank,一種是本次擴展一個重復字符,這算是beam width限制下的一種查漏。
# 對大多數情形前一個時間步就生成了更長的l_plus是不會發生的,所以這一項大多時候不起作用。
if l_plus not in A_prev:
# A(t-1, l_plus) + blank得Ab(t, l_plus)
Pb[t][l_plus] += ctc[t][-1] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus])
# Anb(t-1, l_plus) + c得Anb(t, l_plus)
Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus]
A_next = Pb[t] + Pnb[t]
sorter = (lambda l: A_next[l] * (len(W(l)) + 1) ** beta)
A_prev = sorted(A_next, key=sorter, reverse=True)[:k]
return A_prev[0].strip('>')
最后用的實現
# 負無窮大
NEG_INF = float('-inf')
def logsumexp(*args):
'''
log概率求和,即計算log(a + b)
使用的公式是:
log(a + b) = log(a) + log(1 + exp(log(b) - log(a)))
'''
# args中都是log scale的概率,負無窮代表真實概率為0
if all(a == NEG_INF for a in args):
return NEG_INF
# 用序列最大值當公式中 a
a_max = max(args)
lsp = np.log(sum(np.exp(a - a_max) for a in args))
return a_max + lsp
def prune_vocab(prob, vocab_list, accumulation, max_num=50):
""" 詞匯表裁剪
留下累積概率超過accumulation且最多不超過max_num的類別
"""
assert prob.shape[0] == len(vocab_list)
assert accumulation < 1.0
# 累積概率裁剪
prob = np.exp(prob)
indices = np.argsort(prob)[::-1]
prob_sorted = np.array([prob[ii] for ii in indices])
prob_accumulated = np.add.accumulate(prob_sorted)
index = np.where(prob_accumulated >= accumulation)[0][0]
# 裁剪過后最多不超過max_num個候選
index = min(index, max_num - 1)
part_indices = indices[:index + 1]
# 必須把blank的索引加進來,否則beam_ext有可能全是新的beam,已有的高概率beam會被丟棄
# (此時只有重復元素折疊的情形能保留已有beam,+blank保留beam的情形被丟棄了)
blank_idx = len(vocab_list) - 1
if blank_idx not in part_indices:
part_indices = np.append(part_indices, blank_idx)
part_vocab_list = [vocab_list[ii] for ii in part_indices]
return part_indices, part_vocab_list
def beam_search_decode(ctc, vocab_list, str2idx_table,
beam_width, alpha, beta,
accumulation, lm_func=None):
'''
ctc (np.ndarray) : CTC網絡的輸出,形狀為(time_steps x alphabet_size)
vocab_list (list) : 類別字符列表,即alphabet
beam_width (int) : beam search保存最大概率候選的數目
alpha (float) : 語言模型的條件概率權重,取值0到1
beta (float) : 語言模型的序列長度獎勵權重,避免長序列概率過小,alpha越大beta也應該越大
accumulation (float) : 類別列表中累積概率超過這個數的類別才參與擴展
lm_func (function) : 計算某個前綴的條件概率
'''
# 事件定義:
# 用A(t, l)代表 Path[1:t] -> l,有下面兩種可能
# 用A_b(t, l)代表 Path[1:t] -> l 且 l 末位為blank
# 用A_nb(t, l)代表 Path[1:t] -> l 且 l 末位為non-blank
# 不使用語言模型時始終返回真實概率1.0,對數概率0.0
lm_func = (lambda x: 0.0) if lm_func is None else lm_func
# pb[t][l]事件A_b(t, l)的概率,log概率初始化為負無窮
pb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
# pnb[t][l]事件A_nb(t, l)的概率
pnb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
# 事件A(t, l)的概率為二者之和pb[t][l] + pnb[t][l]
# 路徑擴展除了vocab_list之外還包含blank,用%表示
vocab_list = vocab_list + ['%']
# 為ctc擴展一個想象出來的時間0
num_classes = ctc.shape[1]
ctc = np.vstack((np.zeros((num_classes,), dtype=ctc.dtype), ctc))
ctc = np.log(ctc)
# 比實際時間多一個
T = ctc.shape[0]
empty_prefix = '^'
# 真實概率1.0
pb[0][empty_prefix] = 0.0
# 真實概率0.0
pnb[0][empty_prefix] = NEG_INF
# beams最多存放beam_width個概率最大的(真)前綴,初始化為空前綴
beams = [empty_prefix]
for t in range(1, T):
# 同一個時間步下beams里的不同前綴可能會擴展出相同的新前綴,所以下面在計算
# 某個前綴的概率時不用賦值而用增量,表示兩種來源都會對這一事件有概率貢獻
# 例子: beams里同時有BO和BOX,BO通過添加X擴展為BOX,而BOX擴展X維持不變
# A(t-1,BO)+X和A(t-1,BOX)+X此二者都會促成事件A_nb(t,BOX)但是來源不同是互斥的。
beams_ext = []
ppstart = time.time()
part_indices, part_vocab_list = prune_vocab(ctc[t], vocab_list, accumulation)
if t == 1:
print('prune vocab cost {}'.format(time.time() - ppstart))
for l in beams:
# 還沒擴展之前,時間還是t-1,每一個l都表示事件A(t-1, l)
# 遍歷vocab_list擴展出新事件,討論新事件的概率
# for c_idx, c in enumerate(vocab_list):
for c_idx, c in zip(part_indices, part_vocab_list):
if c == '%':
# A(t-1, l) + blank,前綴不變、時間+1、末位為blank,即得事件A_b(t, l)
# 計算這個事件的概率
# p_b[t][l] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, -1]
ll_start = time.time()
pb[t][l] = logsumexp(pb[t][l], pb[t - 1][l] + ctc[t, -1], pnb[t - 1][l] + ctc[t, -1])
# A(t, l)有發生的可能了,將l添加到beams_ext
if l not in beams_ext:
beams_ext.append(l)
if t == 1:
print('log_sum cost {}'.format(time.time() - ll_start))
else:
# 以下c都為non-blank了
# 僅當覆蓋增加時前綴變化需要應用語言模型
# c_idx = str2idx_table(c)
l_plus = l + c
lm_prob = alpha * lm_func(l_plus)
if len(l) > 0 and c == l[-1]:
# A(t-1, l)中的兩種子事件A_b(t-1, l)和A_nb(t-1, l)經過重復末元素擴展
# 會得到不同的結果,得到兩個事件,分別計算之
# A_b(t-1, l) + l[-1]覆蓋增加得事件A_nb(t, l_plus)
# p_nb[t][l_plus] += p_b[t - 1][l] * ctc[t, c_idx] * lm_func(l_plus) ** alpha
pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
pb[t - 1][l] + ctc[t, c_idx] + lm_prob)
# A_nb(t-1, l) + l[-1]發生重復元素折疊得事件A_nb(t, l)
# p_nb[t][l] += p_nb[t - 1][l] * ctc[t, c_idx]
pnb[t][l] = logsumexp(pnb[t][l],
pnb[t - 1][l] + ctc[t, c_idx])
else:
# c既非重復末元素也非blank,A(t-1, l) + c覆蓋增加且末位非blank,
# 只有一個事件會發生——A_nb(t, l_plus),計算概率
# p_nb[t][l_plus] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, c_idx] * lm_func(l_plus) ** alpha
pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
pb[t - 1][l] + ctc[t, c_idx] + lm_prob,
pnb[t - 1][l] + ctc[t, c_idx] + lm_prob)
# 不管那種情形A(t, l_plus)有可能發生了,將l_plus添加到beams_ext
if l_plus not in beams_ext:
beams_ext.append(l_plus)
print('total {} candidates in beams_ext'.format(len(beams_ext)))
# 將beams_ext中的前綴按概率排序
# prefix_probs = [(p_b[t][ll] + p_nb[t][ll]) * (len(ll) + 1) ** beta for ll in beams_ext]
# 概率
prefix_probs = [logsumexp(pb[t][ll], pnb[t][ll]) for ll in beams_ext]
# 排序分數,加入了序列長度獎勵
sort_score = [logsumexp(pb[t][ll] + beta * (len(ll) + 1),
pnb[t][ll] + beta * (len(ll) + 1))
for ll in beams_ext]
indices = sorted(range(len(sort_score)), key=lambda k: sort_score[k])[::-1][:beam_width]
beams = [beams_ext[ii] for ii in indices]
beams_probs = [prefix_probs[ii] for ii in indices]
beams_scores = [sort_score[ii] for ii in indices]
for ii in range(len(beams)):
print('p(A({}, \'{}\')) = {} score = {}'.
format(t, beams[ii], beams_probs[ii], beams_scores[ii]))
print('-' * 64)
decoded = beams[0]
return decoded