【Learning Notes】CTC 原理及实现

CTC( Connectionist Temporal Classification,连接时序分类)是一种用于序列建模的工具,其核心是定义了特殊的目标函数/优化准则[1]。

jupyter notebook 版见 repo.

1. 算法

这里大体根据 Alex Graves 的开山之作[1],讨论 CTC 的算法原理,并基于 numpy 从零实现 CTC 的推理及训练算法。

1.1 序列问题形式化。

序列问题可以形式化为如下函数:

Nw:(Rm)T(Rn)T N w : ( R m ) T → ( R n ) T

其中,序列目标为字符串(词表大小为 n n ),即 N w 输出为 n n 维多项概率分布(e.g. 经过 softmax 处理)。

网络输出为: y = N w ,其中, ytk y k t t t 表示时刻第 k 项的概率。


图1. 序列建模【src

虽然并没为限定 Nw N w 具体形式,下面为假设其了某种神经网络(e.g. RNN)。
下面代码示例 toy Nw N w

import numpy as np

np.random.seed(1111)

T, V = 12, 5
m, n = 6, V

x = np.random.random([T, m])  # T x m
w = np.random.random([m, n])  # weights, m x n

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def toy_nw(x):
    y = np.matmul(x, w)  # T x n 
    y = softmax(y)
    return y

y = toy_nw(x)
print(y)
print(y.sum(1, keepdims=True))
[[ 0.24654511  0.18837589  0.16937668  0.16757465  0.22812766]
 [ 0.25443629  0.14992236  0.22945293  0.17240658  0.19378184]
 [ 0.24134404  0.17179604  0.23572466  0.12994237  0.22119288]
 [ 0.27216255  0.13054313  0.2679252   0.14184499  0.18752413]
 [ 0.32558002  0.13485564  0.25228604  0.09743785  0.18984045]
 [ 0.23855586  0.14800386  0.23100255  0.17158135  0.21085638]
 [ 0.38534786  0.11524603  0.18220093  0.14617864  0.17102655]
 [ 0.21867406  0.18511892  0.21305488  0.16472572  0.21842642]
 [ 0.29856607  0.13646801  0.27196606  0.11562552  0.17737434]
 [ 0.242347    0.14102063  0.21716951  0.2355229   0.16393996]
 [ 0.26597326  0.10009752  0.23362892  0.24560198  0.15469832]
 [ 0.23337289  0.11918746  0.28540761  0.20197928  0.16005275]]
[[ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [ 1.]]

1.2 align-free 变长映射

上面的形式是输入和输出的一对一的映射。序列学习任务一般而言是多对多的映射关系(如语音识别中,上百帧输出可能仅对应若干音节或字符,并且每个输入和输出之间,也没有清楚的对应关系)。CTC 通过引入一个特殊的 blank 字符(用 % 表示),解决多对一映射问题。

扩展原始词表 L L L = L { blank } 。对输出字符串,定义操作 B B :1)合并连续的相同符号;2)去掉 blank 字符。

例如,对于 “aa%bb%%cc”,应用 B B ,则实际上代表的是字符串 “abc”。同理“%a%b%cc%” 也同样代表 “abc”。

B(aa%bb%%cc)=B(%a%b%cc%)=abc B ( a a % b b % % c c ) = B ( % a % b % c c % ) = a b c

通过引入blank 及 B B ,可以实现了变长的映射。

LTLT L ′ T → L ≤ T

因为这个原因,CTC 只能建模输出长度小于输入长度的序列问题。

1.3 似然计算

和大多数有监督学习一样,CTC 使用最大似然标准进行训练。

给定输入 x x ,输出 l 的条件概率为:

p(l|x)=πB1(l)p(π|x) p ( l | x ) = ∑ π ∈ B − 1 ( l ) p ( π | x )

其中, B1(l) B − 1 ( l ) 表示了长度为 T T 且示经过 B 结果为 l l 字符串的集合。

CTC 假设输出的概率是(相对于输入)条件独立的,因此有:

p ( π | x ) = y π t t , π L T

然而,直接按上式我们没有办理有效的计算似然值。下面用动态规划解决似然的计算及梯度计算, 涉及前向算法和后向算法。

1.4 前向算法

在前向及后向计算中,CTC 需要将输出字符串进行扩展。具体的, (a1,,am) ( a 1 , ⋯ , a m ) 每个字符之间及首尾分别插入 blank,即扩展为 (%,a1,%,a2,%,,%,am,%) ( % , a 1 , % , a 2 , % , ⋯ , % , a m , % ) 。下面的 l l 为原始字符串, l 指为扩展后的字符串。

定义

αt(s)=defπNT:B(π1:t)=l1:st=1tytπ α t ( s ) = d e f ∑ π ∈ N T : B ( π 1 : t ) = l 1 : s ∏ t ′ = 1 t y π ′ t

显然有,

α1(1)=y1b,α1(2)=y1l1,α1(s)=0,s>2(1)(2)(3) (1) α 1 ( 1 ) = y b 1 , (2) α 1 ( 2 ) = y l 1 1 , (3) α 1 ( s ) = 0 , ∀ s > 2

根据 α α 的定义,有如下递归关系:
αt(s)={ (αt1(s)+αt1(s1))ytls,   if ls=b or ls2=ls(αt1(s)+αt1(s1)+αt1(s2))ytls  otherwise α t ( s ) = { ( α t − 1 ( s ) + α t − 1 ( s − 1 ) ) y l s ′ t ,       i f   l s ′ = b   o r   l s − 2 ′ = l s ′ ( α t − 1 ( s ) + α t − 1 ( s − 1 ) + α t − 1 ( s − 2 ) ) y l s ′ t     o t h e r w i s e

1.4.1 Case 2

递归公式中 case 2 是一般的情形。如图所示, t t 时刻字符为 s 为 blank 时,它可能由于两种情况扩展而来:1)重复上一字符,即上个字符也是 a,2)字符发生转换,即上个字符是非 a 的字符。第二种情况又分为两种情形,2.1)上一字符是 blank;2.2)a 由非 blank 字符直接跳转而来( B B ) 操作中, blank 最终会被去掉,因此 blank 并不是必须的)。

图2. 前向算法 Case 2 示例【src

1.4.2 Case 1

递归公式 case 1 是特殊的情形。
如图所示, t t 时刻字符为

  • 56
    点赞
  • 142
    收藏
    觉得还不错? 一键收藏
  • 21
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值