ctc prefix beam search 算法
CTC 网络的输出 net_out 形状为 T×C,其中 T 是时间长度,C 是字符类别数加1(额外的blank)。
CTC 的 beam search 算法维护的不是 K 个路径前缀,而是 K 个标签前缀,但仍需要考虑其背后的路径(路径到标签的多对一关系)。每个时间步,对 K 个前缀进行扩展,用字符表中的字符对已有前缀做扩展,得到新的多个前缀,然后计算这些前缀的概率,从中挑选出概率最大的 K 个保存,不断重复这个过程直到最后一个时间步,然后选出概率最大的一个结果作为最终的标签。
第一种实现
# -*-coding:utf-8-*-
from collections import defaultdict, Counter
from string import ascii_lowercase
import re
import numpy as np
def prefix_beam_search(ctc, lm=None, k=25, alpha=0.30, beta=5, prune=0.001):
"""
对CTC网络的输出做beam search
Args:
ctc (np.ndarray): The CTC网络输出. 2D array形状为(timesteps x alphabet_size)
lm (func): 语言模型函数. 接收一个字符串做参数,输出一个概率
k (int): beam宽度. 每个时间步将保存 k 个概率最大的候选前缀
alpha (float): 语言模型的权重,取值在0到1
beta (float): 语言模型的惩罚(奖励)项权重. alpha越大,beta应该越大
prune (float): ctc的每个时间步的输出分布中概率大于prune的才参与前缀扩展
Retruns:
string: 返回解码结果
"""
# 没有提供语言模型则始终返回1
lm = (lambda l: 1) if lm is None else lm
# 正则匹配l中所有的单词
W = (lambda l: re.findall(r'\w+[\s|>]', l))
# 字母表是小写英文字母及空格、结束符、空白标签
alphabet = list(ascii_lowercase) + [' ', '>', '%']
F = ctc.shape[1]
# 对ctc输出添加一个想象中的时间步0,用于初始化空,手动传个blank进来
ctc = np.vstack((np.zeros(F), ctc))
T = ctc.shape[0]
# 空前缀
origin = ''
# 每个时间步下beam里虽然存的是前缀集合,但需要考虑其背后对应的路径集合
# Pb表示前缀由那些以blank结尾的路径生成的概率,Pnb表示前缀由那些以non-blank结尾的路径生成的概率
Pb, Pnb = defaultdict(Counter), defaultdict(Counter)
# 因为手动传blank进来,所以以blank结尾且前缀为空的概率为1
Pb[0][origin] = 1
# non-blank结尾的路径生成空前缀,不可能发生
Pnb[0][origin] = 0
# A_prev保存当前时间步开始扩展之前所保留的概率最大的候选前缀集合,数目小于等于 k
A_prev = [origin]
# 不断的扩展前缀,始终保留概率最大的 k 个
# 路径空间上的事件定义:
# A(t, l)代表到t步为止生成l,
# Ab(t, l)代表到t步为止生成l且末位为blank,
# Anb(t, l)代表到t步为止生成l且末位为non-blank
for t in range(1, T):
# 对当前时间步的字母表分布,选取概率大于prune的,减少运算
pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]]
# 因为同一个时间步A_prev里不同的前缀可能会扩展出相同的新前缀,所以概率用增量式而不是赋值
for l in A_prev:
# 当前前缀已经到句末了,不能再扩展
if len(l) > 0 and l[-1] == '>':
Pb[t][l] = Pb[t - 1][l]
Pnb[t][l] = Pnb[t - 1][l]
continue
# 每个l都代表着A(t-1, l)事件
for c in pruned_alphabet:
c_ix = alphabet.index(c)
# A(t-1, l)遇到blank,只有一种结果,即Ab(t, l)
# 计算概率贡献P(t-1, l) * P(blank, t)
if c == '%':
Pb[t][l] += ctc[t][-1] * (Pb[t - 1][l] + Pnb[t - 1][l])
else:
l_plus = l + c
# l中有多种路径来源,遇到c后产生多种结果,需要分别计算不同来源
# 经过c扩展后得到不同结果的概率,对特定事件的概率做出贡献
if len(l) > 0 and c == l[-1]:
# A(t-1, l)中来自blank结尾的路径经过c扩展得到l_plus,Anb(t, l_plus)发生,计算概率贡献
Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l]
# A(t-1, l)中来自non-blank结尾的路径经过c扩展,维持l不变,Anb(t, l)发生,计算概率贡献
Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l]
# c既不是l末元素也不是blank,A(t-1, l)+c只有一种结果即Anb(t, l_plus),计算概率贡献,但是
# 计算方式需要根据l当前状态做调整
elif len(l.replace(' ', '')) > 0 and c in (' ', '>'):
lm_prob = lm(l_plus.strip(' >')) ** alpha
Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
else:
Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
# l_plus作为有可能加入A_next的新前缀却没有在上次出现在A_prev中,一种可能是由于
# beam width的限制,A(t-1, l_plus)概率没排在前 k,导致没被A_prev收录,但是
# 在本次扩展结果中,再次需要考虑l_plus是否能加入A_next了,即计算A(t, l_plus)的
# 概率排名,但是由于l_plus不在A_prev中,只能考虑新扩展(l + c)的方式获得的概率贡献。
# 假设没有beam width的限制,l_plus之前是在A_prev中的,那就还有两种方式来在本次获得概率贡献,
# 一种是本次扩展一个blank,一种是本次扩展一个重复字符,这算是beam width限制下的一种查漏。
# 对大多数情形前一个时间步就生成了更长的l_plus是不会发生的,所以这一项大多时候不起作用。
if l_plus not in A_prev:
# A(t-1, l_plus) + blank得Ab(t, l_plus)
Pb[t][l_plus] += ctc[t][-1] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus])
# Anb(t-1, l_plus) + c得Anb(t, l_plus)
Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus]
A_next = Pb[t] + Pnb[t]
sorter = (lambda l: A_next[l] * (len(W(l)) + 1) ** beta)
A_prev = sorted(A_next, key=sorter, reverse=True)[:k]
return A_prev[0].strip('>')
最后用的实现
# 负无穷大
NEG_INF = float('-inf')
def logsumexp(*args):
'''
log概率求和,即计算log(a + b)
使用的公式是:
log(a + b) = log(a) + log(1 + exp(log(b) - log(a)))
'''
# args中都是log scale的概率,负无穷代表真实概率为0
if all(a == NEG_INF for a in args):
return NEG_INF
# 用序列最大值当公式中 a
a_max = max(args)
lsp = np.log(sum(np.exp(a - a_max) for a in args))
return a_max + lsp
def prune_vocab(prob, vocab_list, accumulation, max_num=50):
""" 词汇表裁剪
留下累积概率超过accumulation且最多不超过max_num的类别
"""
assert prob.shape[0] == len(vocab_list)
assert accumulation < 1.0
# 累积概率裁剪
prob = np.exp(prob)
indices = np.argsort(prob)[::-1]
prob_sorted = np.array([prob[ii] for ii in indices])
prob_accumulated = np.add.accumulate(prob_sorted)
index = np.where(prob_accumulated >= accumulation)[0][0]
# 裁剪过后最多不超过max_num个候选
index = min(index, max_num - 1)
part_indices = indices[:index + 1]
# 必须把blank的索引加进来,否则beam_ext有可能全是新的beam,已有的高概率beam会被丢弃
# (此时只有重复元素折叠的情形能保留已有beam,+blank保留beam的情形被丢弃了)
blank_idx = len(vocab_list) - 1
if blank_idx not in part_indices:
part_indices = np.append(part_indices, blank_idx)
part_vocab_list = [vocab_list[ii] for ii in part_indices]
return part_indices, part_vocab_list
def beam_search_decode(ctc, vocab_list, str2idx_table,
beam_width, alpha, beta,
accumulation, lm_func=None):
'''
ctc (np.ndarray) : CTC网络的输出,形状为(time_steps x alphabet_size)
vocab_list (list) : 类别字符列表,即alphabet
beam_width (int) : beam search保存最大概率候选的数目
alpha (float) : 语言模型的条件概率权重,取值0到1
beta (float) : 语言模型的序列长度奖励权重,避免长序列概率过小,alpha越大beta也应该越大
accumulation (float) : 类别列表中累积概率超过这个数的类别才参与扩展
lm_func (function) : 计算某个前缀的条件概率
'''
# 事件定义:
# 用A(t, l)代表 Path[1:t] -> l,有下面两种可能
# 用A_b(t, l)代表 Path[1:t] -> l 且 l 末位为blank
# 用A_nb(t, l)代表 Path[1:t] -> l 且 l 末位为non-blank
# 不使用语言模型时始终返回真实概率1.0,对数概率0.0
lm_func = (lambda x: 0.0) if lm_func is None else lm_func
# pb[t][l]事件A_b(t, l)的概率,log概率初始化为负无穷
pb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
# pnb[t][l]事件A_nb(t, l)的概率
pnb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
# 事件A(t, l)的概率为二者之和pb[t][l] + pnb[t][l]
# 路径扩展除了vocab_list之外还包含blank,用%表示
vocab_list = vocab_list + ['%']
# 为ctc扩展一个想象出来的时间0
num_classes = ctc.shape[1]
ctc = np.vstack((np.zeros((num_classes,), dtype=ctc.dtype), ctc))
ctc = np.log(ctc)
# 比实际时间多一个
T = ctc.shape[0]
empty_prefix = '^'
# 真实概率1.0
pb[0][empty_prefix] = 0.0
# 真实概率0.0
pnb[0][empty_prefix] = NEG_INF
# beams最多存放beam_width个概率最大的(真)前缀,初始化为空前缀
beams = [empty_prefix]
for t in range(1, T):
# 同一个时间步下beams里的不同前缀可能会扩展出相同的新前缀,所以下面在计算
# 某个前缀的概率时不用赋值而用增量,表示两种来源都会对这一事件有概率贡献
# 例子: beams里同时有BO和BOX,BO通过添加X扩展为BOX,而BOX扩展X维持不变
# A(t-1,BO)+X和A(t-1,BOX)+X此二者都会促成事件A_nb(t,BOX)但是来源不同是互斥的。
beams_ext = []
ppstart = time.time()
part_indices, part_vocab_list = prune_vocab(ctc[t], vocab_list, accumulation)
if t == 1:
print('prune vocab cost {}'.format(time.time() - ppstart))
for l in beams:
# 还没扩展之前,时间还是t-1,每一个l都表示事件A(t-1, l)
# 遍历vocab_list扩展出新事件,讨论新事件的概率
# for c_idx, c in enumerate(vocab_list):
for c_idx, c in zip(part_indices, part_vocab_list):
if c == '%':
# A(t-1, l) + blank,前缀不变、时间+1、末位为blank,即得事件A_b(t, l)
# 计算这个事件的概率
# p_b[t][l] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, -1]
ll_start = time.time()
pb[t][l] = logsumexp(pb[t][l], pb[t - 1][l] + ctc[t, -1], pnb[t - 1][l] + ctc[t, -1])
# A(t, l)有发生的可能了,将l添加到beams_ext
if l not in beams_ext:
beams_ext.append(l)
if t == 1:
print('log_sum cost {}'.format(time.time() - ll_start))
else:
# 以下c都为non-blank了
# 仅当覆盖增加时前缀变化需要应用语言模型
# c_idx = str2idx_table(c)
l_plus = l + c
lm_prob = alpha * lm_func(l_plus)
if len(l) > 0 and c == l[-1]:
# A(t-1, l)中的两种子事件A_b(t-1, l)和A_nb(t-1, l)经过重复末元素扩展
# 会得到不同的结果,得到两个事件,分别计算之
# A_b(t-1, l) + l[-1]覆盖增加得事件A_nb(t, l_plus)
# p_nb[t][l_plus] += p_b[t - 1][l] * ctc[t, c_idx] * lm_func(l_plus) ** alpha
pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
pb[t - 1][l] + ctc[t, c_idx] + lm_prob)
# A_nb(t-1, l) + l[-1]发生重复元素折叠得事件A_nb(t, l)
# p_nb[t][l] += p_nb[t - 1][l] * ctc[t, c_idx]
pnb[t][l] = logsumexp(pnb[t][l],
pnb[t - 1][l] + ctc[t, c_idx])
else:
# c既非重复末元素也非blank,A(t-1, l) + c覆盖增加且末位非blank,
# 只有一个事件会发生——A_nb(t, l_plus),计算概率
# p_nb[t][l_plus] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, c_idx] * lm_func(l_plus) ** alpha
pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
pb[t - 1][l] + ctc[t, c_idx] + lm_prob,
pnb[t - 1][l] + ctc[t, c_idx] + lm_prob)
# 不管那种情形A(t, l_plus)有可能发生了,将l_plus添加到beams_ext
if l_plus not in beams_ext:
beams_ext.append(l_plus)
print('total {} candidates in beams_ext'.format(len(beams_ext)))
# 将beams_ext中的前缀按概率排序
# prefix_probs = [(p_b[t][ll] + p_nb[t][ll]) * (len(ll) + 1) ** beta for ll in beams_ext]
# 概率
prefix_probs = [logsumexp(pb[t][ll], pnb[t][ll]) for ll in beams_ext]
# 排序分数,加入了序列长度奖励
sort_score = [logsumexp(pb[t][ll] + beta * (len(ll) + 1),
pnb[t][ll] + beta * (len(ll) + 1))
for ll in beams_ext]
indices = sorted(range(len(sort_score)), key=lambda k: sort_score[k])[::-1][:beam_width]
beams = [beams_ext[ii] for ii in indices]
beams_probs = [prefix_probs[ii] for ii in indices]
beams_scores = [sort_score[ii] for ii in indices]
for ii in range(len(beams)):
print('p(A({}, \'{}\')) = {} score = {}'.
format(t, beams[ii], beams_probs[ii], beams_scores[ii]))
print('-' * 64)
decoded = beams[0]
return decoded