Connectionist Temporal Classification
https://sunnycat2013.gitbooks.io/blogs/content/posts/ctc/learning-ctc.html
因为最近做了一些用连续标签做文字识别标签任务的工作,对 ctc 有了一些了解,在此记录一下。
在学习 CTC 的时候,也看了不少博客,但是我觉得讲的最好的还是原论文 Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks 解释的最清楚。
对于没接触过这个概念的人,可能加一些例子会更好理解一些。
我就来加一些例子。
背景知识
用实例来说,我们在做 ocr 工作时,我们希望给一行文字的图片让机器识别出来这个图片里面的文字。
语音识别任务中,给了一段语音片段,我们希望能把这段语音识别成可编辑的文字。
但是,在对每个片段进行分类模型训练之前,需要对每个训练样本进行切割标注。
这是项非常繁琐的工作非常不利于模型的训练。
下面是个对文字进行标注的工具,大家可以看一下。如果我们在做文字识别工作时,对每个文字都要明确标出这个字在图片中的位置、高、宽,这将会是一个多么巨大的工作量。
RNNs 在序列学习任务中有着优越的性能,但是它也有一些缺点。
如,在上面描述的那种对输入模型的数据需要预处理的缺点。同时,对 RNNs 的序列也需要一定的整合才能得到最终的预测序列。
而 CTC 解决了对输入序列的单个词的切分和对输入序列的整合工作。
RNN 的输出
输入序列的切分与标注上面已经举了一个例子,现在举一个输出序列整合的例子。
我们现在有一个图片的输入 。
假设这个图片中每个红都作为 RNN 的一步输入,那么(如果这个模型训练的还不错的话)它的输出序列应该是 hheelloo
。
但是,我们知道 RNN 每一步的输出其实都是一组概率分布,p(l|x), l \in Alphebatp(l∣x),l∈Alphebat。
如,对第一个矩形框的输出概率可能是 p(l = 'h' | x) = 0.5, p(l = 'm' | x) = 0.3 \cdotsp(l=′h′∣x)=0.5,p(l=′m′∣x)=0.3⋯
时序分类(Temporal Classification)
先给几个定义。
符号 | 解释 |
---|---|
LL | 字母表 Alphebat |
(R^m)^\ast(Rm)∗ | mm 表示输入数据的一个“宽度”,如我们输入的是一个图片时,宽度可以是一个定值。 \ast∗ 表示这个串的长度不定 \in [0, +\infty)∈[0,+∞),如输入图片的长度是未知长度。 |
\chiχ | \chi = (R^m)^\astχ=(Rm)∗ 表示输入数据空间。 |
ZZ | Z = L^\astZ=L∗ 表示由字母表排列而成的标签集合,我们可以理解成单词表。 |
D_{\chi \times Z}Dχ×Z | 真实数据空间。 |
z | z = (z_1, z_2, \cdots , z_U)=(z1,z2,⋯,zU) 是 ZZ 中的一个样本。 |
x | x = (x_1, x_2, \cdots , x_T), U \leq T=(x1,x2,⋯,xT),U≤T,表示一个样本输入数据的输入序列。如一个定高图片每一列像素可以认为是一个 x_ixi。 |
SS | S \subset D_{\chi \times Z}S⊂Dχ×Z 训练样本集,这个集合中的每个样本都是一个 (x, z) 组合 |
S'S′ | S' \subset D_{\chi \times Z}, S' \bigcap S = \emptysetS′⊂Dχ×Z,S′⋂S=∅ |
hh | 时序分类器。 |
由上面的定义,我们可以看出,因为输入和输出和长度未必相等,所以没有办法事先把这两种数据对齐。
目标
时序分类的目标就是学习h: \chi \longmapsto Zh:χ⟼Z
损失函数
用于 CTC 的损失函数是 Lebal Error Rate(LER)
。 这里我们需要知道“最小编辑距离(Edit Distance, ED)”这个概念,在 CTC 的损失函数就用到了。
在学习算法的时候,
ED
算是一个比较经典的动态规划问题,但在实际工作中其实很少用到这类算法。 所以第一次知道这个算法能用在这里,我还是挺开心的。
损失函数定义如下:LER(h, S') = \frac{1}{|S'|}\sum_{(x,z) \in S'}\frac{ED(h(x), z)}{|z|}LER(h,S′)=∣S′∣1(x,z)∈S′∑∣z∣ED(h(x),z)用编辑距离(ED)来衡量文字串的预测情况还是一件蛮符合直观理解的事情。
连接的时序分类(Connectionist Temporal Classification)
写了半天终于到正题了,下面开始讲 CTC! CTC 网络的 softmax
输出层输出的类别有 |L| + 1∣L∣+1 种,因为有一个分隔符,比如说是空格。
这个分隔符其实蛮重要的,它可以很好地区分一个输出序列串中,哪些子串是属于同一个文字的图片区域的输出结果。
一个输出序列的概率
首先,我们来看一个实例:输入 x
的长度是 T
,每一个帧的维度是 m
。模型的输出的长度也是 T
,每一帧的维度是 n
。其中 m
n
可以相同也可以不同。用数学定义我们的这个模型就是:
y = N_{w}(x), N_w: (R^m)^T \longmapsto (R^n)^Ty=Nw(x),Nw:(Rm)T⟼(Rn)T这里,我们引入几个新的概念:
y_{k}^tykt 表示,在时间帧为
t
的时候,模型的第k
个输出值。 我们可以理解 y_{k}^tykt 为,模型认为这次输入的 x_txt 被认为是字母表 L'L′ 中第k
个字母的概率。\piπ 表示一个输出序列的组合,如 y_2^1 y_{20}^2 y_1^3 y_5^4 \cdotsy21y202y13y54⋯ 那么每组输入,对应的输出都有 (R^n)^T(Rn)T 种可能的字母排列,我们用 \piπ 表示其中一种排列。论文中称这种排列为
path
。p(\pi | x)p(π∣x) 一种输出组合的概率。公式如下:p(\pi | x) = \prod y_{\pi_t}^t, \forall_{t = 1}^{T}\pi \in {L'}^T.p(π∣x)=∏yπtt,∀t=1Tπ∈L′T.
如, 的一种可能的输出 hheelloo
的概率可以表示为 p('hheelloo' | hello.png) = y_8^1\ast y_8^2 \ast y_{5}^3 \ast y_{5}^4 \ast y_{12}^5 \ast y_{12}^6 \ast y_{15}^7 \ast y_{15}^8p(′hheelloo′∣hello.png)=y81∗y82∗y53∗y54∗y125∗y126∗y157∗y158
一种输出序列的规整
同一种输入对应的多种输出可能会有多种形式。 如 的输出可能是 hheel-loo
也可能是 hh-ee-l-l-oo
等。这里的 -
表示空格。 原论文处理这种情况的规则非常简单 We do this by simply removing all blanks and repeated labels from the paths
:B(a-ab-) = B(-aa--abb) = aabB(a−ab−)=B(−aa−−abb)=aab一句话:把空格和连续重复的字母去掉。 那么B(hheel-loo) = helloB(hheel−loo)=hello
B(hh-ee-l-l-oo) = helloB(hh−ee−l−l−oo)=hello
则模型预测标签为:l = B(\pi), |l| \leq Tl=B(π),∣l∣≤T
预测标签的概率
我们可以看到预测标签有很多备选的输出序列,所以预测标签 ll 的概率公式:p(l|x) = \sum_{\pi \in B^{-1}(l)}p(\pi|x).p(l∣x)=π∈B−1(l)∑p(π∣x).其中,B^{-1}(l)B−1(l) 是输出序列规整函数 B(\pi)B(π) 的反函数。 如,p(l = 'hello'| x) = p(\pi = 'hheel-loo'|x) + p(\pi = 'hh-ee-l-l-oo' | x) + \cdotsp(l=′hello′∣x)=p(π=′hheel−loo′∣x)+p(π=′hh−ee−l−l−oo′∣x)+⋯
讲到这里,大家就应该明白 CTC 是怎么工作的了。当然还有很多为了实现而做的工作,有时间再接着写吧。