CTC : prefix beam search decode

        CTC 最简单的decode方式当然是最近拿每个frame的最大概率的token,但实际应用中这种方法字错率会颇高,且无法和语言模型结合。

        要与语言模型结合,必须有多个candidate,但也不能够穷尽每个frame每个token的组合,故有beam search。但beam search的candidate会有很多相同的部分,相同的部分应把他们置信度加起来,否则会可能因为单一路径置信度偏小而错误排优,所以有了Prefix beam search。

       如下beam search 三条路径规整化变成右边三个最终序列,

                top1:aadddbbcccck       ---> adbck  = p(adbc) * p(k)

                top2:aaabbbccdddk       ---> abcdk  = p1(abcd) * p(k)

                top3:aaaabbcddddk       ---> abcdk  =  p2(abcd) * p(k)

       可能由于置信度都差不多,但通过bean search , 第一条是最优,但实际不考虑严格对齐,只求规整化序列结果,第二,三条 有相同的前缀abcd, 这时概率p1(abcd) , p2(abcd) 都小于 p(adbc), 但p(abcd) = p1(abcd)  + p2(abcd) 其实大于p(adbc),所以他们应合并为 abcdk 最优路径。

         CTC 的 prefix beam search 算法维护的不是beam 个路径前缀,而是 beam 个标签前缀(存在不同的路径映射到相同规整化前缀),但仍需要考虑其之后的路径。每个时间步 t ,对 beam 个前缀进行扩展,用字符表中的字符对已有前缀做扩展,得到新的多个前缀,然后计算这些前缀的概率,从中挑选出概率最大的 beam 个保存,不断重复这个过程直到最后一个时间步,然后选出概率最大的一个结果作为最终的标签。

        由于前缀是以 blank 还是 non blank (记为 _ ) 结尾符号,对 t 时刻要扩展的符号是有变化影响的,例如前缀序列为*a, 它可以是  *a 和 *a_ 规整成的 , 我们分别记其概率分别为 pnb(*a) 和 pb(*a)。这时有如下几种扩展符号情况:

在 t 时刻对beam 个要扩展的符号分别与当前存在的前缀进行扩展:

1 、t 时刻符号扩展 _ :

*a (*a)    + _ 后 变成 *a_ (*a), 而pnb(a) 没改变
 *a_ (*a) + _ 后 还是为 *a_ (*a), 所以 概率更新 n_pb(*a) += (pb(*a) * p(_)) + (pnb(*a) * p(_))


         注: pnb, pb 当前某个前缀的分别以非blank 和 blank 的概率;

                 n_pnb, n_pb 当前接下来要更新的某个前缀的分别以非blank 和 blank 的概率;

                 n_pb += 用累加是 n_pb 是t时刻保存一样路径的概率和, 同理 n_pnb += ;

                 (a) 括号里的序列表示规整化的序列;

                 *  表示任何规整化序列。

2、t时刻符号扩展符号 a 与当前规整前缀(*a)最后符号相同:

*a  (*a)  + a 后 还是变成 (*a), 从而更新 n_pnb(*a) += pnb(*a) * p(a)
*a_ (*a) + a 后 变成(*aa), 是另一个路径, 则新改路径 n_pnb(*aa) += pb(*a) * p(a)

3、t时刻符号扩展符号 c,既不是 _ , 也不与前缀最后一个符号相同:

只有一种情况,就是扩展为另一个路径,此时更新这个路径概率,
*a + c 后变后 (*ac), *a_ + c 后也变为 (*ac) , 所以 n_pnb(*ac) += pb(*a) * p(c) + pnb(*a) * p(c)

        最后以wenet 中算法为例解析:


        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # cur_hyps (prefix ,(pb, pnb))用于当前已经存在的前缀 prefix, 初始化为 空
        #                             pb 对应前缀 prefix 以 blank 结尾的概率, 初始化为 1 (空前缀等价 blank)
        #                             pnb 对应前缀 prefix 以 非 blank 结尾的概率, 初始化为 0 (空前缀 没有 非 blank)
        cur_hyps = [(tuple(), (0.0, -float('inf')))]
        # 2. CTC beam search step by step
        for t in range(0, maxlen):
          
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            # next_hpys 保存 t 时刻,扩展的路径
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
            # 2.1 First beam prune: select topk best
            # 当前 t 时刻的 beam 个 top 输出
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == 0:  # blank
                        # 扩展符号为 blank, 则当前前缀扩展后还是一样,有两种情况
                        # 情况1,前缀为结尾非 blank : 如 *a (a) + _ 后 变成 *a_ (*a), 而pnb 没改变
                        # 情况2,前缀为结尾 blank : 如 *a_ (*a) + _ 后 还是为 *a_ (*a), 所以 概率更新 n_pb(*a) += (pb(*a) * ps) + (pnb((*a) * ps)
                        # 注: n_pb += 用累加是 n_pb 是时刻保存一样路径的概率和, 同理 n_pnb += 。
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # 扩展符号非 blank 且与前缀最后字符相同,有两种情况
                        # 情况1,前缀结尾是非 blank : 如 *a + a 后 变成 (*a) , 从而更新pnb(*a)  n_pnb += pnb * ps
                        # 情况2,前缀结尾是 blank : 如 *a_ + a 后 变城 (*aa), 是另一个路径, 则新改路径 n_pnb(*aa) += pb(*a) * ps
                        #  Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        # 扩展符号非 blank 且与前缀最后字符不同,则只有一种情况,就是扩展为另一个路径,此时更新这个路径概率
                        # 如 *a + c 后变后 (*ac) , *a_ + c 后也变为 (*ac) , 所以 n_pnb(*ac) += pb(*a) * ps + pnb(*a) * ps
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
            # t 时刻重新更新保存 beam 个最优 前缀。
            next_hyps = sorted(next_hyps.items(),
                               key=lambda x: log_add(list(x[1])),
                               reverse=True)
            cur_hyps = next_hyps[:beam_size]

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值