CTC( Connectionist Temporal Classification,连接时序分类)是一种用于序列建模的工具,其核心是定义了特殊的目标函数/优化准则[1]。
jupyter notebook 版见 repo.
1. 算法
这里大体根据 Alex Graves 的开山之作[1],讨论 CTC 的算法原理,并基于 numpy 从零实现 CTC 的推理及训练算法。
1.1 序列问题形式化。
序列问题可以形式化为如下函数:
其中,序列目标为字符串(词表大小为 n n ),即 输出为 n n 维多项概率分布(e.g. 经过 softmax 处理)。
网络输出为: ,其中, ytk y k t t t 表示时刻第 项的概率。
图1. 序列建模【src】
虽然并没为限定 Nw N w 具体形式,下面为假设其了某种神经网络(e.g. RNN)。
下面代码示例 toy Nw N w :
import numpy as np
np.random.seed(1111)
T, V = 12, 5
m, n = 6, V
x = np.random.random([T, m]) # T x m
w = np.random.random([m, n]) # weights, m x n
def softmax(logits):
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def toy_nw(x):
y = np.matmul(x, w) # T x n
y = softmax(y)
return y
y = toy_nw(x)
print(y)
print(y.sum(1, keepdims=True))
[[ 0.24654511 0.18837589 0.16937668 0.16757465 0.22812766]
[ 0.25443629 0.14992236 0.22945293 0.17240658 0.19378184]
[ 0.24134404 0.17179604 0.23572466 0.12994237 0.22119288]
[ 0.27216255 0.13054313 0.2679252 0.14184499 0.18752413]
[ 0.32558002 0.13485564 0.25228604 0.09743785 0.18984045]
[ 0.23855586 0.14800386 0.23100255 0.17158135 0.21085638]
[ 0.38534786 0.11524603 0.18220093 0.14617864 0.17102655]
[ 0.21867406 0.18511892 0.21305488 0.16472572 0.21842642]
[ 0.29856607 0.13646801 0.27196606 0.11562552 0.17737434]
[ 0.242347 0.14102063 0.21716951 0.2355229 0.16393996]
[ 0.26597326 0.10009752 0.23362892 0.24560198 0.15469832]
[ 0.23337289 0.11918746 0.28540761 0.20197928 0.16005275]]
[[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]
[ 1.]]
1.2 align-free 变长映射
上面的形式是输入和输出的一对一的映射。序列学习任务一般而言是多对多的映射关系(如语音识别中,上百帧输出可能仅对应若干音节或字符,并且每个输入和输出之间,也没有清楚的对应关系)。CTC 通过引入一个特殊的 blank 字符(用 % 表示),解决多对一映射问题。
扩展原始词表 L L 为 。对输出字符串,定义操作 B B :1)合并连续的相同符号;2)去掉 blank 字符。
例如,对于 “aa%bb%%cc”,应用 B B ,则实际上代表的是字符串 “abc”。同理“%a%b%cc%” 也同样代表 “abc”。
通过引入blank 及 B B ,可以实现了变长的映射。
因为这个原因,CTC 只能建模输出长度小于输入长度的序列问题。
1.3 似然计算
和大多数有监督学习一样,CTC 使用最大似然标准进行训练。
给定输入 x x ,输出
的条件概率为:
其中, B−1(l) B − 1 ( l ) 表示了长度为 T T 且示经过 结果为 l l 字符串的集合。
CTC 假设输出的概率是(相对于输入)条件独立的,因此有:
然而,直接按上式我们没有办理有效的计算似然值。下面用动态规划解决似然的计算及梯度计算, 涉及前向算法和后向算法。
1.4 前向算法
在前向及后向计算中,CTC 需要将输出字符串进行扩展。具体的, (a1,⋯,am) ( a 1 , ⋯ , a m ) 每个字符之间及首尾分别插入 blank,即扩展为 (%,a1,%,a2,%,⋯,%,am,%) ( % , a 1 , % , a 2 , % , ⋯ , % , a m , % ) 。下面的 l l 为原始字符串, 指为扩展后的字符串。
定义
显然有,
根据 α α 的定义,有如下递归关系:
1.4.1 Case 2
递归公式中 case 2 是一般的情形。如图所示, t t 时刻字符为
为 blank 时,它可能由于两种情况扩展而来:1)重复上一字符,即上个字符也是 a,2)字符发生转换,即上个字符是非 a 的字符。第二种情况又分为两种情形,2.1)上一字符是 blank;2.2)a 由非 blank 字符直接跳转而来( B B ) 操作中, blank 最终会被去掉,因此 blank 并不是必须的)。
图2. 前向算法 Case 2 示例【src】
1.4.2 Case 1
递归公式 case 1 是特殊的情形。
如图所示, t t 时刻字符为