1. Greedy search
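For comparison with the prefix beam search below, greedy (best-path) decoding simply takes the argmax label at every time step, then collapses the result by merging consecutive repeats and removing blanks. A minimal sketch (the function name `greedy_decode` is ours, not from the reference code):

```python
import numpy as np

def greedy_decode(probs, blank=0):
    """Best-path decoding: argmax label at each time step,
    then merge consecutive repeats and drop blanks."""
    best_path = np.argmax(probs, axis=1)  # shape (T,)
    out, prev = [], None
    for s in best_path:
        if s != prev and s != blank:
            out.append(int(s))
        prev = s
    return out
```

Greedy decoding only considers the single most probable alignment, so it can miss a label sequence whose total probability (summed over many alignments) is higher.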
2. Beam search
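Vanilla beam search keeps the `beam_size` most probable alignments at each time step but, unlike prefix beam search, does not merge alignments that collapse to the same label sequence. A minimal sketch under that assumption (the function name `beam_search_decode` is ours):

```python
import numpy as np

def beam_search_decode(probs, beam_size=3, blank=0):
    """Beam search over raw alignments: keep the beam_size most
    probable alignments (no CTC prefix merging), then collapse
    the best one into a label sequence at the end."""
    T, S = probs.shape
    log_probs = np.log(probs)
    beam = [((), 0.0)]  # (alignment, log-probability)
    for t in range(T):
        candidates = []
        for path, score in beam:
            for s in range(S):
                candidates.append((path + (s,), score + log_probs[t, s]))
        candidates.sort(key=lambda x: x[1], reverse=True)
        beam = candidates[:beam_size]
    best_path, best_score = beam[0]
    # Collapse repeats and remove blanks to get the label sequence.
    out, prev = [], None
    for s in best_path:
        if s != prev and s != blank:
            out.append(s)
        prev = s
    return out, best_score
```

Because several surviving alignments may collapse to the same labels, this wastes beam slots and under-estimates each sequence's probability; prefix beam search fixes both problems by searching over collapsed prefixes directly.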
3. Prefix beam search
The complete code is as follows:
"""
Author: Awni Hannun
This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.
The algorithm is a prefix beam search for a model trained
with the CTC loss function.
For more details checkout either of these references:
https://distill.pub/2017/ctc/#inference
https://arxiv.org/abs/1408.2873
"""
import numpy as np
import math
import collections

NEG_INF = -float("inf")

def make_new_beam():
    fn = lambda: (NEG_INF, NEG_INF)
    return collections.defaultdict(fn)

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp

def decode(probs, beam_size=100, blank=0):
    """
    Performs inference for the given output probabilities.

    Arguments:
        probs: The output probabilities (e.g. post-softmax) for each
            time step. Should be an array of shape (time x output dim).
        beam_size (int): Size of the beam to use during inference.
        blank (int): Index of the CTC blank label.

    Returns the output label sequence and the corresponding negative
    log-likelihood estimated by the decoder.
    """
    T, S = probs.shape
    probs = np.log(probs)

    # Elements in the beam are (prefix, (p_blank, p_no_blank))
    # Initialize the beam with the empty sequence, a probability of
    # 1 for ending in blank and zero for ending in non-blank
    # (in log space).
    beam = [(tuple(), (0.0, NEG_INF))]

    for t in range(T):  # Loop over time
        # A default dictionary to store the next step candidates.
        next_beam = make_new_beam()

        for s in range(S):  # Loop over vocab
            p = probs[t, s]

            # The variables p_b and p_nb are respectively the
            # probabilities for the prefix given that it ends in a
            # blank and does not end in a blank at this time step.
            for prefix, (p_b, p_nb) in beam:  # Loop over beam

                # If we propose a blank the prefix doesn't change.
                # Only the probability of ending in blank gets updated.
                if s == blank:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
                    continue

                # Extend the prefix by the new character s and add it to
                # the beam. Only the probability of not ending in blank
                # gets updated.
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)
                n_p_b, n_p_nb = next_beam[n_prefix]
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # We don't include the previous probability of not ending
                    # in blank (p_nb) if s is repeated at the end. The CTC
                    # algorithm merges characters not separated by a blank.
                    n_p_nb = logsumexp(n_p_nb, p_b + p)

                # *NB* this would be a good place to include an LM score.
                next_beam[n_prefix] = (n_p_b, n_p_nb)

                # If s is repeated at the end we also update the unchanged
                # prefix. This is the merging case.
                if s == end_t:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_nb = logsumexp(n_p_nb, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)

        # Sort and trim the beam before moving on to the
        # next time-step.
        beam = sorted(next_beam.items(),
                      key=lambda x: logsumexp(*x[1]),
                      reverse=True)
        beam = beam[:beam_size]

    best = beam[0]
    return best[0], -logsumexp(*best[1])

if __name__ == "__main__":
    np.random.seed(3)

    time = 50
    output_dim = 20

    probs = np.random.rand(time, output_dim)
    probs = probs / np.sum(probs, axis=1, keepdims=True)

    labels, score = decode(probs)
    print("Score {:.3f}".format(score))
First, beam is a list of elements of the form (prefix, (p_blank, p_no_blank)). Here prefix is a path up to the current time step after the CTC collapse, i.e. with consecutive repeated characters merged and blanks removed. p_blank and p_no_blank are, respectively, the total probability of all un-collapsed paths ending in blank and the total probability of all those ending in a non-blank character. For example, at T=3 the paths -ab- and abbb both collapse to ab, the current prefix; but at T=4, appending the character b turns -ab-b into abb while abbbb still collapses to ab. The two probabilities are kept separately so that these cases can be handled correctly.
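The collapse rule above can be made concrete with a tiny helper (our own illustration, not part of the decoder, using '-' for blank):

```python
def collapse(path, blank='-'):
    """Apply the CTC collapse: merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for c in path:
        if c != prev and c != blank:
            out.append(c)
        prev = c
    return ''.join(out)

# '-ab-' and 'abbb' both collapse to 'ab'; extending each with 'b'
# gives '-ab-b' -> 'abb' but 'abbbb' -> 'ab', which is exactly why
# the blank-ending and non-blank-ending probabilities are tracked
# separately.
```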
The beam is initialized with the single element (tuple(), (0.0, NEG_INF)): the prefix is empty, and NEG_INF is negative infinity. The probabilities are stored in the log domain, so applying exp recovers p_blank = 1 and p_no_blank = 0, i.e. the probability of ending in blank is 1 and the probability of ending in a non-blank is 0.
The outermost loop iterates over the T time steps. At each new time step, make_new_beam creates a next_beam. beam holds the paths and probabilities selected after processing all characters of the previous time step, while next_beam accumulates the updated paths and probabilities for the current one. next_beam starts empty, so looking up any prefix returns n_p_b = n_p_nb = NEG_INF, i.e. probability 0 (the n prefix stands for "new", matching p_b and p_nb in beam).
The function logsumexp sums path probabilities: it first recovers the probabilities from the log domain with exp, adds them, and converts the result back with log. Subtracting a_max from each term before exponentiating and adding it back at the end keeps the computation numerically stable (see the log-domain optimization (3) in CTC Loss (2)).
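A quick numerical check of this identity, reusing the same logsumexp as in the code: summing the path probabilities 0.2 and 0.3 in the log domain should recover log(0.5).

```python
import math

NEG_INF = -float("inf")

def logsumexp(*args):
    # Stable log-sum-exp: subtract the max before exponentiating
    # so no exp() call overflows, then add it back afterwards.
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

# log(0.2) "+" log(0.3) in probability space:
total = logsumexp(math.log(0.2), math.log(0.3))
```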
if s == blank:
    n_p_b, n_p_nb = next_beam[prefix]
    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
    next_beam[prefix] = (n_p_b, n_p_nb)
When the current character s is a blank, only n_p_b is updated. The sums p_b + p and p_nb + p are probability multiplications in the log domain: p_b and p_nb are the probabilities of prefix ending in a blank and in a non-blank at the previous time step, each multiplied by the probability p of the blank at the current step. The prefix itself is unchanged; only its probability of ending in blank, n_p_b, is updated.
When the current character s is not a blank, there are two cases depending on whether s equals the last character of prefix. When it does not, s is appended to prefix and the probability of ending in a non-blank is updated:
end_t = prefix[-1] if prefix else None
n_prefix = prefix + (s,)
n_p_b, n_p_nb = next_beam[n_prefix]
if s != end_t:
    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
next_beam[n_prefix] = (n_p_b, n_p_nb)
When it does, the update of the non-blank-ending probability n_p_nb excludes the previous non-blank-ending probability p_nb, because the probability being updated belongs to prefix + s, and the CTC rule merges consecutive repeated characters. For example, suppose prefix is ab, and -ab- and abbb are two of the many paths, ending in blank and in non-blank respectively, that collapse to this prefix. With the current character s = 'b', prefix + s = abb; -ab-b collapses to abb, but abbbb still collapses to ab, so only the blank-ending probability p_b contributes here:
else:
    # We don't include the previous probability of not ending
    # in blank (p_nb) if s is repeated at the end. The CTC
    # algorithm merges characters not separated by a blank.
    n_p_nb = logsumexp(n_p_nb, p_b + p)
next_beam[n_prefix] = (n_p_b, n_p_nb)
In the example above, abbb + s = abbbb, which still collapses to ab. So although this path does not contribute to prefix + s, the probability of the original prefix must be updated. Here the probabilities looked up in next_beam are those of the original prefix, and the update of n_p_nb excludes the blank-ending probability p_b, since p_b's contribution has already been counted toward prefix + s:
if s == end_t:
    n_p_b, n_p_nb = next_beam[prefix]
    n_p_nb = logsumexp(n_p_nb, p_nb + p)
    next_beam[prefix] = (n_p_b, n_p_nb)
While iterating over the time steps, each new step starts with an empty next_beam; within the step we loop over every character and over every path selected into beam at the previous step. Once all updates for the step are done, the paths in next_beam are sorted by probability and the beam_size most probable ones are kept for the next step. Note that for sorting, each prefix's probability is the sum of its blank-ending and non-blank-ending probabilities:
# Sort and trim the beam before moving on to the
# next time-step.
beam = sorted(next_beam.items(),
              key=lambda x: logsumexp(*x[1]),
              reverse=True)
beam = beam[:beam_size]