CTC Loss (Part 3)

1. Greedy search

2. Beam search

3. Prefix beam search

The complete code is as follows.

"""
Author: Awni Hannun
This is an example CTC decoder written in Python. The code is
intended to be a simple example and is not designed to be
especially efficient.
The algorithm is a prefix beam search for a model trained
with the CTC loss function.
For more details check out either of these references:
  https://distill.pub/2017/ctc/#inference
  https://arxiv.org/abs/1408.2873
"""

import numpy as np
import math
import collections

NEG_INF = -float("inf")

def make_new_beam():
  fn = lambda : (NEG_INF, NEG_INF)
  return collections.defaultdict(fn)

def logsumexp(*args):
  """
  Stable log sum exp.
  """
  if all(a == NEG_INF for a in args):
      return NEG_INF
  a_max = max(args)
  lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
  return a_max + lsp

def decode(probs, beam_size=100, blank=0):
  """
  Performs inference for the given output probabilities.
  Arguments:
      probs: The output probabilities (e.g. post-softmax) for each
        time step. Should be an array of shape (time x output dim).
      beam_size (int): Size of the beam to use during inference.
      blank (int): Index of the CTC blank label.
  Returns the output label sequence and the corresponding negative
  log-likelihood estimated by the decoder.
  """
  T, S = probs.shape
  probs = np.log(probs)

  # Elements in the beam are (prefix, (p_blank, p_no_blank))
  # Initialize the beam with the empty sequence, a probability of
  # 1 for ending in blank and zero for ending in non-blank
  # (in log space).
  beam = [(tuple(), (0.0, NEG_INF))]

  for t in range(T): # Loop over time

    # A default dictionary to store the next step candidates.
    next_beam = make_new_beam()

    for s in range(S): # Loop over vocab
      p = probs[t, s]

      # The variables p_b and p_nb are respectively the
      # probabilities for the prefix given that it ends in a
      # blank and does not end in a blank at this time step.
      for prefix, (p_b, p_nb) in beam: # Loop over beam

        # If we propose a blank the prefix doesn't change.
        # Only the probability of ending in blank gets updated.
        if s == blank:
          n_p_b, n_p_nb = next_beam[prefix]
          n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
          next_beam[prefix] = (n_p_b, n_p_nb)
          continue

        # Extend the prefix by the new character s and add it to
        # the beam. Only the probability of not ending in blank
        # gets updated.
        end_t = prefix[-1] if prefix else None
        n_prefix = prefix + (s,)
        n_p_b, n_p_nb = next_beam[n_prefix]
        if s != end_t:
          n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
        else:
          # We don't include the previous probability of not ending
          # in blank (p_nb) if s is repeated at the end. The CTC
          # algorithm merges characters not separated by a blank.
          n_p_nb = logsumexp(n_p_nb, p_b + p)
          
        # *NB* this would be a good place to include an LM score.
        next_beam[n_prefix] = (n_p_b, n_p_nb)

        # If s is repeated at the end we also update the unchanged
        # prefix. This is the merging case.
        if s == end_t:
          n_p_b, n_p_nb = next_beam[prefix]
          n_p_nb = logsumexp(n_p_nb, p_nb + p)
          next_beam[prefix] = (n_p_b, n_p_nb)

    # Sort and trim the beam before moving on to the
    # next time-step.
    beam = sorted(next_beam.items(),
            key=lambda x : logsumexp(*x[1]),
            reverse=True)
    beam = beam[:beam_size]

  best = beam[0]
  return best[0], -logsumexp(*best[1])

if __name__ == "__main__":
  np.random.seed(3)

  time = 50
  output_dim = 20

  probs = np.random.rand(time, output_dim)
  probs = probs / np.sum(probs, axis=1, keepdims=True)

  labels, score = decode(probs)
  print("Score {:.3f}".format(score))

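First, beam is a list of elements of the form (prefix, (p_blank, p_no_blank)). prefix is one labeling up to the current time step; note that it has already been passed through the \beta transform, i.e. consecutive repeated characters are merged and blanks are removed. p_blank and p_no_blank are the total probabilities of all pre-\beta paths that collapse to this prefix and end in a blank or in a non-blank character, respectively. For example, after four time steps (t=3 with 0-based indexing), the paths -ab- and abbb both collapse to ab, which is the current prefix. If the character b is emitted at the next step (t=4), -ab-b collapses to abb, whereas abbbb still collapses to ab. The two probabilities are therefore stored separately so that both cases can be computed.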
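The collapsing examples above can be checked with a small helper. This is only an illustrative sketch: the function name collapse and the use of "-" as the blank symbol are assumptions made for this example and are not part of the decoder listing.

```python
import itertools

BLANK = "-"  # assumed blank symbol, used only in this illustration


def collapse(path, blank=BLANK):
    """Merge consecutive repeats first, then drop blanks (the beta transform)."""
    merged = (ch for ch, _ in itertools.groupby(path))
    return "".join(ch for ch in merged if ch != blank)


for path in ["-ab-", "abbb", "-ab-b", "abbbb"]:
    print(path, "->", collapse(path))
# -ab- -> ab
# abbb -> ab
# -ab-b -> abb
# abbbb -> ab
```

The order matters: repeats are merged before blanks are removed, which is why a blank between two b characters keeps them distinct.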

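At the start, the only element of beam is initialized to (tuple(), (0.0, NEG_INF)): the prefix is empty and NEG_INF is negative infinity. The probabilities are stored in the log domain; applying exp recovers p_blank=1 and p_no_blank=0, i.e. the empty prefix ends in a blank with probability 1 and in a non-blank character with probability 0.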

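The outermost loop iterates over the T time steps. At each new time step, make_new_beam is called to create next_beam. beam holds the prefixes and probabilities obtained after processing every character of the previous time step, while next_beam holds the prefixes and probabilities updated with every character of the current time step. next_beam starts out empty, and looking up any prefix that has not been stored yet returns NEG_INF for both n_p_b and n_p_nb, i.e. probability 0. The n stands for new; n_p_b and n_p_nb correspond to p_b and p_nb in beam.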
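A short sketch of the two facts just described: the log-domain initial state and the default value returned by next_beam. The prefix (3, 7) is an arbitrary value chosen for illustration.

```python
import collections
import math

NEG_INF = -float("inf")

# Initial beam entry: empty prefix, log(1) for blank-ending, log(0) otherwise.
init_prefix, (init_p_b, init_p_nb) = (tuple(), (0.0, NEG_INF))
print(math.exp(init_p_b), math.exp(init_p_nb))  # 1.0 0.0

# Same construction as make_new_beam in the listing.
next_beam = collections.defaultdict(lambda: (NEG_INF, NEG_INF))
unseen = next_beam[(3, 7)]       # look up a prefix that was never stored
print(unseen)                    # (-inf, -inf)
print((3, 7) in next_beam)       # True: the lookup itself created the entry
```

Because a defaultdict inserts the default on lookup, reading next_beam[prefix] and then writing the updated pair back, as the listing does, never raises a KeyError.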

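The function logsumexp computes a sum of path probabilities in the log domain: conceptually it maps each log-probability back with exp, adds the probabilities, and converts the result back with log. For why a_max is subtracted from every term first and added back at the end, see log-domain optimization (3) in CTC Loss (Part 2).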
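To see why the a_max shift matters, the sketch below compares logsumexp from the listing with a direct evaluation. naive_logsumexp is a hypothetical helper written only for this comparison, and the input values are arbitrary.

```python
import math

NEG_INF = -float("inf")


def logsumexp(*args):
    """Stable log-sum-exp, same logic as in the listing."""
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def naive_logsumexp(*args):
    """Direct evaluation; exp underflows to 0 for very negative inputs."""
    return math.log(sum(math.exp(a) for a in args))


# Ordinary values: both versions agree with log(0.2 + 0.3) = log(0.5).
small = (math.log(0.2), math.log(0.3))
print(logsumexp(*small), naive_logsumexp(*small), math.log(0.5))

# Very negative log-probabilities, typical of long sequences.
tiny = (-1000.0, -1001.0)
print(logsumexp(*tiny))  # roughly -999.69
try:
    naive_logsumexp(*tiny)
except ValueError as err:
    # exp(-1000) underflows to 0.0, so the naive version ends up taking log(0).
    print("naive version failed:", err)
```

After the shift, the largest term becomes exp(0) = 1, so the sum inside the log is always at least 1 and cannot underflow.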

if s == blank:
    n_p_b, n_p_nb = next_beam[prefix]
    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
    next_beam[prefix] = (n_p_b, n_p_nb)

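When the current character s is blank, only n_p_b is updated. p_b + p and p_nb + p are really products of probabilities, since ln a + ln b = ln(ab). p_b and p_nb are the probabilities that the prefix ended in a blank or in a non-blank character at the previous time step; each is multiplied by the probability of blank at the current time step. The prefix is unchanged, and its blank-ending probability n_p_b is updated.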
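A small numeric check of the blank update. The probabilities 0.3, 0.2 and 0.5 are made-up values for illustration; in the probability domain the result should be (0.3 + 0.2) * 0.5 = 0.25.

```python
import math

NEG_INF = -float("inf")


def logsumexp(*args):
    """Stable log-sum-exp, same logic as in the listing."""
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


# Assumed values: the prefix ended in blank with probability 0.3 and in a
# non-blank with probability 0.2; blank has probability 0.5 at this step.
p_b, p_nb = math.log(0.3), math.log(0.2)
p = math.log(0.5)

# Fresh entry in next_beam, as returned by the defaultdict.
n_p_b, n_p_nb = NEG_INF, NEG_INF
n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)

print(math.exp(n_p_b))  # approximately 0.25
print(n_p_nb)           # -inf: the non-blank-ending probability is untouched
```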

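When the current character s is not blank, there are two cases, depending on whether s equals the last character of the prefix. When it differs, s is appended to the prefix and the non-blank-ending probability of the new prefix is updated. The code is as follows.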

end_t = prefix[-1] if prefix else None
n_prefix = prefix + (s,)
n_p_b, n_p_nb = next_beam[n_prefix]
if s != end_t:
    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
next_beam[n_prefix] = (n_p_b, n_p_nb)

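When they are the same, the update of the non-blank-ending probability n_p_nb excludes p_nb, the previous probability of ending in a non-blank character, because the probability being updated is that of prefix+s and consecutive repeated characters are merged by the collapsing rule. For example, suppose the prefix is ab, and -ab- and abbb are two of the many paths that collapse to it, ending in a blank and in a non-blank respectively. If the current character is s='b', then prefix+s=abb: -ab-b collapses to abb, but abbbb still collapses to ab. Therefore only the blank-ending probability p_b contributes to the update. The code is as follows.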

else:
    # We don't include the previous probability of not ending
    # in blank (p_nb) if s is repeated at the end. The CTC
    # algorithm merges characters not separated by a blank.
    n_p_nb = logsumexp(n_p_nb, p_b + p)
next_beam[n_prefix] = (n_p_b, n_p_nb)

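In the example above, abbb+s=abbbb still collapses to ab, so although these paths do not contribute to prefix+s, they must be credited to the original prefix. Here the entry read from next_beam is that of the original prefix, and the update of n_p_nb excludes the blank-ending probability p_b, because the contribution of p_b has already gone to prefix+s. The code is as follows.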

if s == end_t:
    n_p_b, n_p_nb = next_beam[prefix]
    n_p_nb = logsumexp(n_p_nb, p_nb + p)
    next_beam[prefix] = (n_p_b, n_p_nb)
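The non-blank cases can be tried out on made-up numbers. In the sketch below, propose_char is a hypothetical wrapper around the non-blank update lines of the listing, the labels A=1 and B=2 stand in for the characters a and b, and the probabilities 0.3, 0.2 and 0.5 are assumptions chosen for illustration.

```python
import collections
import math

NEG_INF = -float("inf")


def logsumexp(*args):
    """Stable log-sum-exp, same logic as in the listing."""
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def propose_char(next_beam, prefix, p_b, p_nb, s, p):
    """Apply the listing's non-blank update for one (prefix, s) pair."""
    end_t = prefix[-1] if prefix else None
    n_prefix = prefix + (s,)
    n_p_b, n_p_nb = next_beam[n_prefix]
    if s != end_t:
        n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
    else:
        # Repeated character: only blank-ending paths extend the prefix.
        n_p_nb = logsumexp(n_p_nb, p_b + p)
    next_beam[n_prefix] = (n_p_b, n_p_nb)
    if s == end_t:
        # Non-blank-ending paths merge into the unchanged prefix.
        n_p_b, n_p_nb = next_beam[prefix]
        n_p_nb = logsumexp(n_p_nb, p_nb + p)
        next_beam[prefix] = (n_p_b, n_p_nb)


A, B = 1, 2                      # stand-ins for the characters a and b
prefix = (A, B)                  # the prefix "ab"
p_b, p_nb = math.log(0.3), math.log(0.2)
p = math.log(0.5)                # probability of the proposed character

# Case 1: s differs from the last character, so "ab" becomes "aba".
beam_diff = collections.defaultdict(lambda: (NEG_INF, NEG_INF))
propose_char(beam_diff, prefix, p_b, p_nb, A, p)
print(math.exp(beam_diff[(A, B, A)][1]))  # approximately (0.3 + 0.2) * 0.5

# Case 2: s repeats the last character, so the mass is split.
beam_same = collections.defaultdict(lambda: (NEG_INF, NEG_INF))
propose_char(beam_same, prefix, p_b, p_nb, B, p)
print(math.exp(beam_same[(A, B, B)][1]))  # approximately 0.3 * 0.5, goes to "abb"
print(math.exp(beam_same[(A, B)][1]))     # approximately 0.2 * 0.5, stays with "ab"
```

In the repeated case the two contributions still add up to (0.3 + 0.2) * 0.5, so no probability mass is lost; it is only routed to two different prefixes.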

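To summarize the loop over time steps: each new t starts with an empty next_beam. Within step t, the code iterates over every character s and every prefix in beam selected at the previous step. Once step t has been fully processed, all prefixes in next_beam are sorted by probability and the beam_size most probable ones are kept for the next step. Note that the probability used for sorting is the sum of the blank-ending and non-blank-ending probabilities of each prefix. The code is as follows.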

# Sort and trim the beam before moving on to the
# next time-step.
beam = sorted(next_beam.items(),
              key=lambda x: logsumexp(*x[1]),
              reverse=True)
beam = beam[:beam_size]
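One way to sanity-check the update rules and the combined score p_b + p_nb is to compare against brute-force enumeration on a tiny input. The sketch below is not the listing itself: prefix_search_no_pruning re-implements the same updates in the probability domain and never trims the beam, brute_force enumerates every path, and the 4 x 3 random input is an assumption chosen so that enumeration stays cheap (81 paths).

```python
import collections
import itertools

import numpy as np


def collapse(path, blank=0):
    """Merge consecutive repeats, then drop blanks."""
    merged = [s for s, _ in itertools.groupby(path)]
    return tuple(s for s in merged if s != blank)


def brute_force(probs, blank=0):
    """Sum the probability of every path, grouped by its collapsed labeling."""
    T, S = probs.shape
    totals = collections.defaultdict(float)
    for path in itertools.product(range(S), repeat=T):
        p = 1.0
        for t, s in enumerate(path):
            p *= probs[t, s]
        totals[collapse(path, blank)] += p
    return totals


def prefix_search_no_pruning(probs, blank=0):
    """Same update rules as the listing, probability domain, beam never trimmed."""
    T, S = probs.shape
    beam = {(): (1.0, 0.0)}
    for t in range(T):
        next_beam = collections.defaultdict(lambda: (0.0, 0.0))
        for s in range(S):
            p = probs[t, s]
            for prefix, (p_b, p_nb) in beam.items():
                if s == blank:
                    b, nb = next_beam[prefix]
                    next_beam[prefix] = (b + (p_b + p_nb) * p, nb)
                    continue
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)
                b, nb = next_beam[n_prefix]
                if s != end_t:
                    next_beam[n_prefix] = (b, nb + (p_b + p_nb) * p)
                else:
                    next_beam[n_prefix] = (b, nb + p_b * p)
                    b, nb = next_beam[prefix]
                    next_beam[prefix] = (b, nb + p_nb * p)
        beam = dict(next_beam)
    # The score of a prefix is the sum of its two ending probabilities.
    return {prefix: b + nb for prefix, (b, nb) in beam.items()}


rng = np.random.RandomState(0)
probs = rng.rand(4, 3)
probs /= probs.sum(axis=1, keepdims=True)

exact = brute_force(probs)
searched = prefix_search_no_pruning(probs)
max_gap = max(abs(exact[k] - searched.get(k, 0.0)) for k in exact)
best = max(searched, key=searched.get)
print("largest difference:", max_gap)
print("best labeling:", best, "matches brute force:", best == max(exact, key=exact.get))
```

With the beam untrimmed the two results should agree up to floating-point error. Trimming to beam_size, as the listing does, trades this exactness for speed, so the pruned search is only an approximation of the most probable labeling.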

References

https://distill.pub/2017/ctc/
