CTC 解码算法之 prefix beam search

ctc prefix beam search 算法

1520134-20181024234754541-1448772948.png

CTC 网络的输出 net_out 形状为 T×C,其中 T 是时间长度,C 是字符类别数加1(额外的blank)。
CTC 的 beam search 算法维护的不是 K 个路径前缀,而是 K 个标签前缀,但仍需要考虑其背后的路径(路径到标签的多对一关系)。每个时间步,对 K 个前缀进行扩展,用字符表中的字符对已有前缀做扩展,得到新的多个前缀,然后计算这些前缀的概率,从中挑选出概率最大的 K 个保存,不断重复这个过程直到最后一个时间步,然后选出概率最大的一个结果作为最终的标签。

第一种实现

# -*-coding:utf-8-*-
from collections import defaultdict, Counter
from string import ascii_lowercase
import re
import numpy as np


def prefix_beam_search(ctc, lm=None, k=25, alpha=0.30, beta=5, prune=0.001):
    """
    对CTC网络的输出做beam search
    Args:
            ctc (np.ndarray): The CTC网络输出. 2D array形状为(timesteps x alphabet_size)
            lm (func): 语言模型函数. 接收一个字符串做参数,输出一个概率
            k (int): beam宽度. 每个时间步将保存 k 个概率最大的候选前缀
            alpha (float): 语言模型的权重,取值在0到1
            beta (float): 语言模型的惩罚(奖励)项权重. alpha越大,beta应该越大
            prune (float): ctc的每个时间步的输出分布中概率大于prune的才参与前缀扩展

    Retruns:
            string: 返回解码结果
    """

    # 没有提供语言模型则始终返回1
    lm = (lambda l: 1) if lm is None else lm
    # 正则匹配l中所有的单词
    W = (lambda l: re.findall(r'\w+[\s|>]', l))
    # 字母表是小写英文字母及空格、结束符、空白标签
    alphabet = list(ascii_lowercase) + [' ', '>', '%']
    F = ctc.shape[1]
    # 对ctc输出添加一个想象中的时间步0,用于初始化空,手动传个blank进来
    ctc = np.vstack((np.zeros(F), ctc))
    T = ctc.shape[0]

    # 空前缀
    origin = ''
    # 每个时间步下beam里虽然存的是前缀集合,但需要考虑其背后对应的路径集合
    # Pb表示前缀由那些以blank结尾的路径生成的概率,Pnb表示前缀由那些以non-blank结尾的路径生成的概率
    Pb, Pnb = defaultdict(Counter), defaultdict(Counter)
    # 因为手动传blank进来,所以以blank结尾且前缀为空的概率为1
    Pb[0][origin] = 1
    # non-blank结尾的路径生成空前缀,不可能发生
    Pnb[0][origin] = 0
    # A_prev保存当前时间步开始扩展之前所保留的概率最大的候选前缀集合,数目小于等于 k
    A_prev = [origin]
    # 不断的扩展前缀,始终保留概率最大的 k 个
    # 路径空间上的事件定义:
    # A(t, l)代表到t步为止生成l,
    # Ab(t, l)代表到t步为止生成l且末位为blank,
    # Anb(t, l)代表到t步为止生成l且末位为non-blank
    for t in range(1, T):
        # 对当前时间步的字母表分布,选取概率大于prune的,减少运算
        pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]]
        # 因为同一个时间步A_prev里不同的前缀可能会扩展出相同的新前缀,所以概率用增量式而不是赋值
        for l in A_prev:
            # 当前前缀已经到句末了,不能再扩展
            if len(l) > 0 and l[-1] == '>':
                Pb[t][l] = Pb[t - 1][l]
                Pnb[t][l] = Pnb[t - 1][l]
                continue
            # 每个l都代表着A(t-1, l)事件
            for c in pruned_alphabet:
                c_ix = alphabet.index(c)
                # A(t-1, l)遇到blank,只有一种结果,即Ab(t, l)
                # 计算概率贡献P(t-1, l) * P(blank, t)
                if c == '%':
                    Pb[t][l] += ctc[t][-1] * (Pb[t - 1][l] + Pnb[t - 1][l])
                else:
                    l_plus = l + c
                    # l中有多种路径来源,遇到c后产生多种结果,需要分别计算不同来源
                    # 经过c扩展后得到不同结果的概率,对特定事件的概率做出贡献
                    if len(l) > 0 and c == l[-1]:
                        # A(t-1, l)中来自blank结尾的路径经过c扩展得到l_plus,Anb(t, l_plus)发生,计算概率贡献
                        Pnb[t][l_plus] += ctc[t][c_ix] * Pb[t - 1][l]
                        # A(t-1, l)中来自non-blank结尾的路径经过c扩展,维持l不变,Anb(t, l)发生,计算概率贡献
                        Pnb[t][l] += ctc[t][c_ix] * Pnb[t - 1][l]
                    # c既不是l末元素也不是blank,A(t-1, l)+c只有一种结果即Anb(t, l_plus),计算概率贡献,但是
                    # 计算方式需要根据l当前状态做调整
                    elif len(l.replace(' ', '')) > 0 and c in (' ', '>'):
                        lm_prob = lm(l_plus.strip(' >')) ** alpha
                        Pnb[t][l_plus] += lm_prob * ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
                    else:
                        Pnb[t][l_plus] += ctc[t][c_ix] * (Pb[t - 1][l] + Pnb[t - 1][l])
                    # l_plus作为有可能加入A_next的新前缀却没有在上次出现在A_prev中,一种可能是由于
                    # beam width的限制,A(t-1, l_plus)概率没排在前 k,导致没被A_prev收录,但是
                    # 在本次扩展结果中,再次需要考虑l_plus是否能加入A_next了,即计算A(t, l_plus)的
                    # 概率排名,但是由于l_plus不在A_prev中,只能考虑新扩展(l + c)的方式获得的概率贡献。
                    # 假设没有beam width的限制,l_plus之前是在A_prev中的,那就还有两种方式来在本次获得概率贡献,
                    # 一种是本次扩展一个blank,一种是本次扩展一个重复字符,这算是beam width限制下的一种查漏。
                    # 对大多数情形前一个时间步就生成了更长的l_plus是不会发生的,所以这一项大多时候不起作用。
                    if l_plus not in A_prev:
                        # A(t-1, l_plus) + blank得Ab(t, l_plus)
                        Pb[t][l_plus] += ctc[t][-1] * (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus])
                        # Anb(t-1, l_plus) + c得Anb(t, l_plus)
                        Pnb[t][l_plus] += ctc[t][c_ix] * Pnb[t - 1][l_plus]
        A_next = Pb[t] + Pnb[t]
        sorter = (lambda l: A_next[l] * (len(W(l)) + 1) ** beta)
        A_prev = sorted(A_next, key=sorter, reverse=True)[:k]
    return A_prev[0].strip('>')

最后用的实现

# 负无穷大
NEG_INF = float('-inf')

def logsumexp(*args):
    '''
    log概率求和,即计算log(a + b)
    使用的公式是:
    log(a + b) = log(a) + log(1 + exp(log(b) - log(a)))
    '''
    # args中都是log scale的概率,负无穷代表真实概率为0
    if all(a == NEG_INF for a in args):
        return NEG_INF
    # 用序列最大值当公式中 a
    a_max = max(args)
    lsp = np.log(sum(np.exp(a - a_max) for a in args))
    return a_max + lsp

def prune_vocab(prob, vocab_list, accumulation, max_num=50):
    """ 词汇表裁剪
    留下累积概率超过accumulation且最多不超过max_num的类别
    """
    assert prob.shape[0] == len(vocab_list)
    assert accumulation < 1.0
    # 累积概率裁剪
    prob = np.exp(prob)
    indices = np.argsort(prob)[::-1]
    prob_sorted = np.array([prob[ii] for ii in indices])
    prob_accumulated = np.add.accumulate(prob_sorted)
    index = np.where(prob_accumulated >= accumulation)[0][0]
    # 裁剪过后最多不超过max_num个候选
    index = min(index, max_num - 1)
    part_indices = indices[:index + 1]
    # 必须把blank的索引加进来,否则beam_ext有可能全是新的beam,已有的高概率beam会被丢弃
    # (此时只有重复元素折叠的情形能保留已有beam,+blank保留beam的情形被丢弃了)
    blank_idx = len(vocab_list) - 1
    if blank_idx not in part_indices:
        part_indices = np.append(part_indices, blank_idx)
    part_vocab_list = [vocab_list[ii] for ii in part_indices]
    return part_indices, part_vocab_list

def beam_search_decode(ctc, vocab_list, str2idx_table,
                       beam_width, alpha, beta,
                       accumulation, lm_func=None):
    '''
    ctc (np.ndarray) : CTC网络的输出,形状为(time_steps x alphabet_size)
    vocab_list (list) : 类别字符列表,即alphabet
    beam_width (int) : beam search保存最大概率候选的数目
    alpha (float) : 语言模型的条件概率权重,取值0到1
    beta (float) : 语言模型的序列长度奖励权重,避免长序列概率过小,alpha越大beta也应该越大
    accumulation (float) : 类别列表中累积概率超过这个数的类别才参与扩展
    lm_func (function) : 计算某个前缀的条件概率
    '''
    # 事件定义:
    # 用A(t, l)代表 Path[1:t] -> l,有下面两种可能
    # 用A_b(t, l)代表 Path[1:t] -> l 且 l 末位为blank
    # 用A_nb(t, l)代表 Path[1:t] -> l 且 l 末位为non-blank

    # 不使用语言模型时始终返回真实概率1.0,对数概率0.0
    lm_func = (lambda x: 0.0) if lm_func is None else lm_func
    # pb[t][l]事件A_b(t, l)的概率,log概率初始化为负无穷
    pb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
    # pnb[t][l]事件A_nb(t, l)的概率
    pnb = defaultdict(lambda: defaultdict(lambda: float('-inf')))
    # 事件A(t, l)的概率为二者之和pb[t][l] + pnb[t][l]

    # 路径扩展除了vocab_list之外还包含blank,用%表示
    vocab_list = vocab_list + ['%']
    # 为ctc扩展一个想象出来的时间0
    num_classes = ctc.shape[1]
    ctc = np.vstack((np.zeros((num_classes,), dtype=ctc.dtype), ctc))
    ctc = np.log(ctc)
    # 比实际时间多一个
    T = ctc.shape[0]

    empty_prefix = '^'
    # 真实概率1.0
    pb[0][empty_prefix] = 0.0
    # 真实概率0.0
    pnb[0][empty_prefix] = NEG_INF
    # beams最多存放beam_width个概率最大的(真)前缀,初始化为空前缀
    beams = [empty_prefix]

    for t in range(1, T):
        # 同一个时间步下beams里的不同前缀可能会扩展出相同的新前缀,所以下面在计算
        # 某个前缀的概率时不用赋值而用增量,表示两种来源都会对这一事件有概率贡献
        # 例子: beams里同时有BO和BOX,BO通过添加X扩展为BOX,而BOX扩展X维持不变
        # A(t-1,BO)+X和A(t-1,BOX)+X此二者都会促成事件A_nb(t,BOX)但是来源不同是互斥的。
        beams_ext = []
        ppstart = time.time()
        part_indices, part_vocab_list = prune_vocab(ctc[t], vocab_list, accumulation)
        if t == 1:
            print('prune vocab cost {}'.format(time.time() - ppstart))
        for l in beams:
            # 还没扩展之前,时间还是t-1,每一个l都表示事件A(t-1, l)
            # 遍历vocab_list扩展出新事件,讨论新事件的概率
            # for c_idx, c in enumerate(vocab_list):
            for c_idx, c in zip(part_indices, part_vocab_list):
                if c == '%':
                    # A(t-1, l) + blank,前缀不变、时间+1、末位为blank,即得事件A_b(t, l)
                    # 计算这个事件的概率
                    # p_b[t][l] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, -1]
                    ll_start = time.time()
                    pb[t][l] = logsumexp(pb[t][l], pb[t - 1][l] + ctc[t, -1], pnb[t - 1][l] + ctc[t, -1])
                    # A(t, l)有发生的可能了,将l添加到beams_ext
                    if l not in beams_ext:
                        beams_ext.append(l)
                    if t == 1:
                        print('log_sum cost {}'.format(time.time() - ll_start))
                else:
                    # 以下c都为non-blank了
                    # 仅当覆盖增加时前缀变化需要应用语言模型
                    # c_idx = str2idx_table(c)
                    l_plus = l + c
                    lm_prob = alpha * lm_func(l_plus)
                    if len(l) > 0 and c == l[-1]:
                        # A(t-1, l)中的两种子事件A_b(t-1, l)和A_nb(t-1, l)经过重复末元素扩展
                        # 会得到不同的结果,得到两个事件,分别计算之
                        # A_b(t-1, l) + l[-1]覆盖增加得事件A_nb(t, l_plus)
                        # p_nb[t][l_plus] += p_b[t - 1][l] * ctc[t, c_idx] * lm_func(l_plus) ** alpha
                        pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
                                                   pb[t - 1][l] + ctc[t, c_idx] + lm_prob)
                        # A_nb(t-1, l) + l[-1]发生重复元素折叠得事件A_nb(t, l)
                        # p_nb[t][l] += p_nb[t - 1][l] * ctc[t, c_idx]
                        pnb[t][l] = logsumexp(pnb[t][l],
                                              pnb[t - 1][l] + ctc[t, c_idx])
                    else:
                        # c既非重复末元素也非blank,A(t-1, l) + c覆盖增加且末位非blank,
                        # 只有一个事件会发生——A_nb(t, l_plus),计算概率
                        # p_nb[t][l_plus] += (p_b[t - 1][l] + p_nb[t - 1][l]) * ctc[t, c_idx] * lm_func(l_plus) ** alpha
                        pnb[t][l_plus] = logsumexp(pnb[t][l_plus],
                                                   pb[t - 1][l] + ctc[t, c_idx] + lm_prob,
                                                   pnb[t - 1][l] + ctc[t, c_idx] + lm_prob)
                    # 不管那种情形A(t, l_plus)有可能发生了,将l_plus添加到beams_ext
                    if l_plus not in beams_ext:
                        beams_ext.append(l_plus)
        print('total {} candidates in beams_ext'.format(len(beams_ext)))
        # 将beams_ext中的前缀按概率排序
        # prefix_probs = [(p_b[t][ll] + p_nb[t][ll]) * (len(ll) + 1) ** beta for ll in beams_ext]
        # 概率
        prefix_probs = [logsumexp(pb[t][ll], pnb[t][ll]) for ll in beams_ext]
        # 排序分数,加入了序列长度奖励
        sort_score = [logsumexp(pb[t][ll] + beta * (len(ll) + 1),
                                pnb[t][ll] + beta * (len(ll) + 1))
                      for ll in beams_ext]
        indices = sorted(range(len(sort_score)), key=lambda k: sort_score[k])[::-1][:beam_width]
        beams = [beams_ext[ii] for ii in indices]
        beams_probs = [prefix_probs[ii] for ii in indices]
        beams_scores = [sort_score[ii] for ii in indices]

        for ii in range(len(beams)):
            print('p(A({}, \'{}\')) = {} score = {}'.
                  format(t, beams[ii], beams_probs[ii], beams_scores[ii]))
        print('-' * 64)
    decoded = beams[0]
    return decoded




转载于:https://www.cnblogs.com/cookcooller/p/9846999.html

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值