CTC loss的几种解码方法:贪心搜索 (greedy search)、束搜索(Beam Search)、前缀束搜索(Prefix Beam Search)
前言:
预测新的样本输入对应的输出字符串,这涉及到解码。按照最大似然准则,最优的解码结果为:
例:
如上图的例子,按照时间序列展开得到栅格网络,解码的过程相当于空间搜索, 求取穷举的所有可能字符串序列中概率最大的那个。我们可以选择暴力的解码策略:穷举搜索,但时间复杂度是指数级的N^{T},显然不可行。
然而,上式不存在已知的高效解法。下面介绍几种实用的近似破解码方法。
1 贪心搜索 (greedy search)
原理:
虽然 p(l|x)
难以有效的计算,但是由于 CTC 的独立性假设,对于某个具体的字符串 π(去 blank 前),确容易计算:
因此,我们放弃寻找使 p(l|x) 最大的字符串,退而寻找一个使 p(π|x) 最大的字符串,即:
其中,
简化后,解码过程(构造 π⋆)变得非常简单(基于独立性假设): 在每个时刻输出概率最大的字符:
Greedy search 是在每一步选择概率最大的输出值,这样就可以得到最终解码的输出序列(如上图例子,最终解码的输出序列l=blank
)。然而,CTC网络的输出序列只对应了搜索空间的一条路径,一个最终标签可对应搜索空间的N条路径,所以概率最大的路径并不等于最终标签的概率最大,即不是最优解(如上图例子,最优解是p(l=b)
而不是p(l=blank))
。
图示:
代码:
def remove_blank(labels, blank=0):
import numpy as np
# 求每一列(即每个时刻)中最大值对应的softmax值
def softmax(logits):
# 注意这里求e的次方时,次方数减去max_value其实不影响结果,因为最后可以化简成教科书上softmax的定义
# 次方数加入减max_value是因为e的x次方与x的极限(x趋于无穷)为无穷,很容易溢出,所以为了计算时不溢出,就加入减max_value项
# 次方数减去max_value后,e的该次方数总是在0到1范围内。
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def remove_blank(labels, blank=0):
new_labels = []
# 合并相同的标签
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# 删除blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def greedy_decode(y, blank=0):
# 按列取最大值,即每个时刻t上最大值对应的下标
raw_rs = np.argmax(y, axis=1)
# 移除blank,值为0的位置表示这个位置是blank
rs = remove_blank(raw_rs, blank)
return raw_rs, rs
np.random.seed(1111)
y_test = softmax(np.random.random([20, 6]))
label_have_blank, label_no_blank = greedy_decode(y_test)
print(label_have_blank)
print(label_no_blank)
2 束搜索(Beam Search)
贪心搜索的性能非常受限, 这种方法忽略了一个输出可能对应多个对齐结果。很多时候,如果我们能拿到nearbest的路径,后续可以利用其他信息来进一步优化搜索的结果。束搜索能近似找出 top 最优的若干条路径。
原理:
基本原理是通过 t i − 1 t_{i-1} ti−1中beamsize
个序列,每个序列分别连接 t i t_{i} ti中beamsize
个节点,得到 beamsize
个新序列及对应的score,然后按照score从大到小的顺序选出前beamSize
个序列,依次推进。
图示:
假设 beamsize
为2
t=1时:
这个时候只会将两个概率最大的节点放进路径集合中,即有两条路径。
t=2时:
上面的两个路径每个路径都会和下一个时间点的每一项组成新的路径,因此一共有 b e a m s i z e × V = 2 ∗ 3 = 6 beamsize\times V=2*3=6 beamsize×V=2∗3