在做OCR时用到了CTC Loss,对CTC Loss一直都是只有宏观的概念,并没有认真研究它的细节原理(主要是没勇气研究),最近由于需要修改CTC中的解码部分,所以又硬着头皮看论文,查资料,经过不懈努力,总算是明白了一点。接下来我将按照我的理解对CTC的损失计算、解码进行详细说明,限于本人水平有限,不对之处,敬请指正。
CTC出现的背景
在序列学习任务中,RNN对训练样本一般有这样的依赖条件:输入序列和输出序列之间的映射关系已经事先标注好了,可以根据输出序列和标注样本间的差异来直接定义RNN模型的Loss函数。比如,在词性标注任务中,训练样本中每个词(或短语)对应的词性会事先标注好。
但是,在OCR、语音识别时,由于我们很难对样本的输入进行标注(很难区分相邻信息间的分界线),所以仅使用RNN是很难解决这些问题的。这时Alex Graves等人在ICML 2006上提出的一种端到端的RNN训练方法Connectionist Temporal Classification(CTC),它可以让RNN直接对序列数据进行学习,而无需事先标注好训练数据中输入序列和输入序列的映射关系,使得RNN模型在语音识别等序列学习任务中取得更好的效果,在语音识别和图像识别等领域CTC算法都有很比较广泛的应用。
CTC介绍
假设RNN(一般都是经过了
s
o
f
t
m
a
x
softmax
softmax的)的某一条输出为
π
=
{
π
1
,
π
2
,
.
.
π
n
}
\pi=\{\pi_1,\pi_2,..\pi_n\}
π={π1,π2,..πn},对应的标签为
I
I
I,m <= n,CTC的目的就是为了将
π
\pi
π通过一个函数B映射成
I
I
I,即:
I
=
B
(
π
)
I=B(\pi)
I=B(π)。
y
π
t
t
y_{\pi_t}^{t}
yπtt表示在
t
t
t时刻输出为
π
t
\pi_t
πt的概率。
假设每个输出之间都是相互独立的,则其中一条符合
l
=
B
(
π
)
l=B(\pi)
l=B(π)的路径的概率为:
p
(
π
∣
x
)
=
∏
t
=
1
T
y
π
t
t
p(\pi|x)=\prod_{t=1}^{T}y_{\pi_t}^{t}
p(π∣x)=∏t=1Tyπtt
够映射成
I
I
I的概率为:
p
(
I
∣
x
)
=
∑
π
∈
B
−
1
(
I
)
p
(
π
∣
x
)
p(I|x)=\sum_{\pi\in B^{-1}(I)}p(\pi|x)
p(I∣x)=∑π∈B−1(I)p(π∣x),这里的
π
∈
B
−
1
(
I
)
\pi\in B^{-1}(I)
π∈B−1(I)是指所有能够映射成
l
l
l的
π
\pi
π
由于直接暴力计算
p
(
I
∣
x
)
p(I|x)
p(I∣x)的复杂度非常高,作者借鉴HMM的Forward-Backward算法思路,利用动态规划算法求解。
在正式介绍前向,后向算法之前,我们先说明一些条件,方便后续的理解。
- 将目标序列 I I I转化为label,在目标序列的首尾和中间都加上空格,用 l ′ l^{'} l′表示。如上图所示:我们的目标序列是:CAT,将CAT的首尾和中间都添加空格(blank),变成了:-C-A-T-,图中白色代表实体,黑色代表空格。
- 路径的搜索只能从左上方往右下方进行,不能低于当前位置
- 相同字符之间至少需要一个空格。比如:序列aa之间至少有一个“-”,否则就是错误的,因为不包含"-"的会被合并成一个a
- 非空字符不能被跳过。搜索过程中非空字符必须要对应一个输出
- 起点必须从第一个(空白)或第二个(第一个非空字符)开始,终点必须在最后一个(空白)或第二个(最后一个非空字符)结束。
前向算法
这里用
α
(
t
,
u
)
=
∑
π
∈
V
(
s
,
u
)
∏
i
=
1
t
y
π
i
i
\alpha(t,u)=\sum_{\pi\in V(s,u)}\prod_{i=1}^{t} y_{\pi_i}^{i}
α(t,u)=∑π∈V(s,u)∏i=1tyπii表示
t
t
t时刻经过节点
u
u
u的路径的概率总和(
u
u
u是
l
′
l^{'}
l′的索引,从1开始),特别的当
t
=
1
t=1
t=1时:
α
(
1
,
1
)
=
y
b
1
α
(
1
,
2
)
=
y
l
1
α
(
1
,
u
)
=
0
,
u
>
2
\begin{aligned} & \alpha(1, 1)=y_b^1 \\ & \alpha(1, 2)=y_{l_1} \\ & \alpha(1, u)=0,\space\space u\gt2 \end{aligned}
α(1,1)=yb1α(1,2)=yl1α(1,u)=0, u>2
其他时刻需要分情况考虑:
- t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为空白时,那么能够到达它的节点为 ( u , t − 1 ) (u, t-1) (u,t−1)、 ( u − 1 , t − 1 ) (u-1, t-1) (u−1,t−1),可以表达为: α ( t , u ) = ( α ( t − 1 , u ) + α ( t − 1 , u − 1 ) ) ∗ y u t \alpha(t,u)=(\alpha(t-1,u)+\alpha(t-1,u-1))*y_{u}^t α(t,u)=(α(t−1,u)+α(t−1,u−1))∗yut;
- t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为非空字符且与前一个非空字符相同时,那么能够到达它的节点为 ( u , t − 1 ) (u, t-1) (u,t−1)、 ( u − 1 , t − 1 ) (u-1, t-1) (u−1,t−1),可以表达为: α ( t , u ) = ( α ( t − 1 , u ) + α ( t − 1 , u − 1 ) ) ∗ y u t \alpha(t,u)=(\alpha(t-1,u)+\alpha(t-1,u-1))*y_{u}^t α(t,u)=(α(t−1,u)+α(t−1,u−1))∗yut,与1一样;
-
t
t
t时刻经过的结点
(
u
,
t
)
(u, t)
(u,t)为非空字符且与前一个非空字符不相同时,那么能够到达它的节点为
(
u
,
t
−
1
)
(u, t-1)
(u,t−1)、
(
u
−
1
,
t
−
1
)
(u-1, t-1)
(u−1,t−1)、
(
u
−
2
,
t
−
1
)
(u-2,t-1)
(u−2,t−1),可以表达为:
α
(
t
,
u
)
=
(
α
(
t
−
1
,
u
)
+
α
(
t
−
1
,
u
−
1
)
+
α
(
t
−
1
,
u
−
2
)
)
∗
y
u
t
\alpha(t,u)=(\alpha(t-1,u)+\alpha(t-1,u-1)+\alpha(t-1,u-2))*y_{u}^t
α(t,u)=(α(t−1,u)+α(t−1,u−1)+α(t−1,u−2))∗yut。
论文中用一下表述该概括这三种情况:
α ( t , u ) = y u t ∑ i = f ( u ) u α ( t − 1 , i ) \begin{aligned} & \alpha(t,u)=y_u^t \sum_{i=f(u)}^{u}\alpha(t-1,i) \\ \end{aligned} α(t,u)=yuti=f(u)∑uα(t−1,i)
其中: f ( u ) = { u − 1 l ′ [ u ] = b l a n k o r l ′ [ u ] = l ′ [ u − 2 ] u − 2 otherwise f(u)= \begin{cases} u-1& \text{$l^{'}[u]=blank \space or \space l^{'}[u]=l^{'}[u-2]$}\\ u-2& \text{otherwise} \end{cases} f(u)={u−1u−2l′[u]=blank or l′[u]=l′[u−2]otherwise
最后,总的损失(考虑最后一个是空格和非空两种情况, , ∣ l ′ ∣ ,|l^{'}| ,∣l′∣表示label的长度)可以表示为:
L ( S ) = − l n ( I ∣ x ) = − l n ( α ( T , ∣ l ′ ∣ ) + α ( T , ∣ l ′ ∣ − 1 ) ) L(S)=-ln(I|x)=-ln(\alpha(T,|l^{'}|)+\alpha(T,|l^{'}|-1)) L(S)=−ln(I∣x)=−ln(α(T,∣l′∣)+α(T,∣l′∣−1))
后向算法
后向算法与前向算法一样,只是方向是反的。前向算法是从
t
=
1
t=1
t=1到
t
=
T
t=T
t=T,后向算法是
t
=
T
t=T
t=T到
t
=
1
t=1
t=1。
这里用
β
(
t
,
u
)
=
∑
π
∈
V
(
s
,
u
)
∏
i
=
t
+
1
T
y
π
i
i
\beta(t,u)=\sum_{\pi\in V(s,u)}\prod_{i=t+1}^{T} y_{\pi_i}^{i}
β(t,u)=∑π∈V(s,u)∏i=t+1Tyπii表示
t
t
t时刻经过节点
u
u
u的路径的概率总和(
u
u
u是
l
′
l^{'}
l′的索引,从1开始),特别的当
t
=
T
t=T
t=T时:
β
(
T
,
∣
l
′
∣
)
=
1
β
(
T
,
∣
l
′
∣
−
1
)
=
1
β
(
T
,
u
)
=
0
,
u
<
∣
l
′
∣
−
2
\begin{aligned} & \beta(T, |l^{'}|)=1 \\ & \beta(T, |l^{'}|-1)=1 \\ & \beta(T, u)=0,\space\space u\lt|l^{'}|-2 \end{aligned}
β(T,∣l′∣)=1β(T,∣l′∣−1)=1β(T,u)=0, u<∣l′∣−2
其他时刻需要分情况考虑:
- t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为空白时,那么能够到达它的节点为 ( u , t + 1 ) (u, t+1) (u,t+1)、 ( u + 1 , t + 1 ) (u+1, t+1) (u+1,t+1),可以表达为: β ( t , u ) = ( β ( t + 1 , u ) + β ( t + 1 , u + 1 ) ) ∗ y u t + 1 \beta(t,u)=(\beta(t+1,u)+\beta(t+1,u+1))*y_{u}^{t+1} β(t,u)=(β(t+1,u)+β(t+1,u+1))∗yut+1;
- t t t时刻经过的结点 ( u , t ) (u, t) (u,t)为非空字符且与前一个非空字符相同时,那么能够到达它的节点为 ( u , t + 1 ) (u, t+1) (u,t+1)、 ( u + 1 , t + 1 ) (u+1, t+1) (u+1,t+1),可以表达为: β ( t , u ) = ( β ( t + 1 , u ) + β ( t + 1 , u + 1 ) ) ∗ y u t + 1 \beta(t,u)=(\beta(t+1,u)+\beta(t+1,u+1))*y_{u}^{t+1} β(t,u)=(β(t+1,u)+β(t+1,u+1))∗yut+1,与1一样;
-
t
t
t时刻经过的结点
(
u
,
t
)
(u, t)
(u,t)为非空字符且与前一个非空字符不相同时,那么能够到达它的节点为
(
u
,
t
+
1
)
(u, t+1)
(u,t+1)、
(
u
+
1
,
t
+
1
)
(u+1, t+1)
(u+1,t+1)、
(
u
+
2
,
t
+
1
)
(u+2,t+1)
(u+2,t+1),可以表达为:
β
(
t
,
u
)
=
(
β
(
t
+
1
,
u
)
+
β
(
t
+
1
,
u
+
1
)
+
β
(
t
+
1
,
u
+
2
)
)
∗
y
u
t
+
1
\beta(t,u)=(\beta(t+1,u)+\beta(t+1,u+1)+\beta(t+1,u+2))*y_{u}^{t+1}
β(t,u)=(β(t+1,u)+β(t+1,u+1)+β(t+1,u+2))∗yut+1。
论文中用一下表述该概括这三种情况:
β ( t , u ) = y u t + 1 ∑ i = f ( u ) u β ( t + 1 , i ) \begin{aligned} & \beta(t,u)=y_{u}^{t+1} \sum_{i=f(u)}^{u}\beta(t+1,i) \\ \end{aligned} β(t,u)=yut+1i=f(u)∑uβ(t+1,i)
其中: f ( u ) = { u + 1 l ′ [ u ] = b l a n k o r l ′ [ u ] = l ′ [ u + 2 ] u + 2 otherwise f(u)= \begin{cases} u+1& \text{$l^{'}[u]=blank \space or \space l^{'}[u]=l^{'}[u+2]$}\\ u+2& \text{otherwise} \end{cases} f(u)={u+1u+2l′[u]=blank or l′[u]=l′[u+2]otherwise
最后,总的损失(考虑最后一个是空格和非空两种情况, , ∣ l ′ ∣ ,|l^{'}| ,∣l′∣表示label的长度)可以表示为:
L ( S ) = − l n ( I ∣ x ) = − l n ( β ( 1 , 1 ) + β ( 1 , 2 ) ) L(S)=-ln(I|x)=-ln(\beta(1,1)+\beta(1,2)) L(S)=−ln(I∣x)=−ln(β(1,1)+β(1,2))
损失函数
这里我们可以利用前向算法和后向算法来表示
t
t
t时刻通过节点
u
u
u的概率:
α
(
t
,
u
)
β
(
t
,
u
)
=
∑
π
∈
X
(
t
,
u
)
∏
t
=
1
T
y
π
t
t
=
∑
π
∈
X
(
t
,
u
)
p
(
π
∣
x
)
\begin{aligned} \alpha(t,u)\beta(t,u) & =\sum_{\pi \in X(t, u)}\prod_{t=1}^{T}y_{\pi_t}^{t} \\ & =\sum_{\pi \in X(t, u)}p(\pi|x) \\ \end{aligned}
α(t,u)β(t,u)=π∈X(t,u)∑t=1∏Tyπtt=π∈X(t,u)∑p(π∣x)
之前我们只是表示了总的损失,那么
t
t
t时刻的损失如何表示了,论文中做了表述,
t
t
t时刻的损失可以表示为:
KaTeX parse error: Undefined control sequence: \inX at position 58: …=-\ln(\sum_{\pi\̲i̲n̲X̲(t,u)}\prod_{t=…
梯度反向传播
如上图所示,我们要求损失关于
u
k
t
u_k^t
ukt的梯度。在求解之前,我们需要先做一些准备工作(用
z
z
z表示
I
I
I,论文是这样表达的):
∂
L
(
z
,
x
)
∂
y
k
′
t
=
∂
(
−
ln
p
(
z
∣
x
)
)
∂
y
k
′
t
=
−
1
p
(
z
∣
x
)
∂
∑
π
∈
B
−
1
(
z
)
p
(
π
∣
x
)
∂
y
k
′
t
=
−
1
p
(
z
∣
x
)
(
∑
u
∈
B
(
z
,
k
′
)
∂
α
(
t
,
u
)
β
(
t
,
u
)
∂
y
k
′
t
)
=
−
1
p
(
z
∣
x
)
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
y
k
′
t
\begin{aligned} \frac{\partial L(z,x)}{\partial y_{k^{'}}^t} & = \frac{\partial(-\ln p(z|x))}{\partial y_{k^{'}}^t} \\ & =-\frac{1}{p(z|x)} \frac{\partial \sum_{\pi \in B^{-1}(z)}p(\pi|x)}{\partial y_{k^{'}}^t} \\ & =-\frac{1}{p(z|x)} (\sum_{u \in B(z,k^{'})}\frac{\partial \alpha(t,u)\beta(t,u)}{\partial y_{k^{'}}^{t}}) \\ & =-\frac{1}{p(z|x)} \sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}} \\ \end{aligned}
∂yk′t∂L(z,x)=∂yk′t∂(−lnp(z∣x))=−p(z∣x)1∂yk′t∂∑π∈B−1(z)p(π∣x)=−p(z∣x)1(u∈B(z,k′)∑∂yk′t∂α(t,u)β(t,u))=−p(z∣x)1u∈B(z,k′)∑yk′tα(t,u)β(t,u)
∂
y
k
′
t
∂
u
k
t
=
y
k
′
t
(
δ
k
k
′
−
y
k
t
)
\begin{aligned} \frac{\partial y_{k^{'}}^t}{\partial u_k^t}=y_{k^{'}}^t(\delta_{kk^{'}}-y_k^t) \end{aligned}
∂ukt∂yk′t=yk′t(δkk′−ykt)
这个是softmax的求导,具体过程这里就不累赘了。
其中:
δ
k
k
′
=
{
1
k
=
k
′
0
otherwise
\delta_{kk^{'}}= \begin{cases} 1& \text{$k=k^{'}$}\\ 0& \text{otherwise} \end{cases}
δkk′={10k=k′otherwise
如上图所示,如果我们要求损失关于
u
k
t
u_k^t
ukt的梯度,则:
∂
L
(
z
,
x
)
∂
u
k
t
=
∑
k
′
∂
L
(
z
,
x
)
∂
y
k
′
t
∂
y
k
′
t
∂
u
k
t
=
−
∑
k
′
1
p
(
z
∣
x
)
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
y
k
′
t
)
y
k
′
t
(
δ
k
k
′
−
y
k
t
)
=
−
1
p
(
z
∣
x
)
∑
k
′
(
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
y
k
′
t
)
y
k
′
t
(
δ
k
k
′
−
y
k
t
)
)
=
−
1
p
(
z
∣
x
)
(
∑
k
′
=
k
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
y
k
′
t
)
y
k
′
t
(
1
−
y
k
t
)
+
∑
k
′
!
=
k
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
y
k
′
t
)
y
k
′
t
(
0
−
y
k
t
)
)
=
−
1
p
(
z
∣
x
)
(
∑
k
′
=
k
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
)
(
1
−
y
k
t
)
−
∑
k
′
!
=
k
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
)
y
k
t
)
=
−
1
p
(
z
∣
x
)
(
∑
k
′
=
k
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
−
∑
k
′
=
k
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
)
y
k
t
)
−
∑
k
′
!
=
k
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
)
y
k
t
)
=
−
1
p
(
z
∣
x
)
(
∑
u
∈
B
(
z
,
k
)
α
(
t
,
u
)
β
(
t
,
u
)
−
∑
k
′
(
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
)
y
k
t
)
=
1
p
(
z
∣
x
)
(
−
∑
u
∈
B
(
z
,
k
)
α
(
t
,
u
)
β
(
t
,
u
)
+
y
k
t
∑
u
∈
B
(
z
,
k
′
)
α
(
t
,
u
)
β
(
t
,
u
)
)
=
1
p
(
z
∣
x
)
(
−
∑
u
∈
B
(
z
,
k
)
α
(
t
,
u
)
β
(
t
,
u
)
+
y
k
t
p
(
z
∣
x
)
)
=
y
k
t
−
1
p
(
z
∣
x
)
∑
u
∈
B
(
z
,
k
)
α
(
t
,
u
)
β
(
t
,
u
)
\begin{aligned} \frac{\partial L(z,x)}{\partial u_k^t} & = \sum_{k^{'}} \frac{\partial L(z,x)}{\partial y_{k^{'}}^t}\frac{\partial y_{k^{'}}^t}{\partial u_k^t} \\ & = -\sum_{k^{'}}\frac{1}{p(z|x)} (\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(\delta_{kk^{'}}-y_k^t) \\ & = -\frac{1}{p(z|x)}\sum_{k^{'}}((\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(\delta_{kk^{'}}-y_k^t)) \\ & = -\frac{1}{p(z|x)}(\sum_{k^{'}=k}(\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(1-y_k^t) +\sum_{k^{'}!=k}(\sum_{u \in B(z,k^{'})} \frac{\alpha(t,u)\beta(t,u)}{y_{k^{'}}^{t}}) y_{k^{'}}^t(0-y_k^t)) \\ & = -\frac{1}{p(z|x)}(\sum_{k^{'}=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))(1-y_k^t) -\sum_{k^{'}!=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) \\ & = -\frac{1}{p(z|x)}(\sum_{k^{'}=k} \sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u) - \sum_{k^{'}=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) - \sum_{k^{'}!=k}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) \\ & = -\frac{1}{p(z|x)}( \sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) - \sum_{k^{'}}(\sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u))y_k^t) \\ & = \frac{1}{p(z|x)}(-\sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) + y_k^t \sum_{u \in B(z,k^{'})} \alpha(t,u)\beta(t,u)) \\ & = \frac{1}{p(z|x)}(-\sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) + y_k^t p(z|x)) \\ & = y_k^t - \frac{1}{p(z|x)}\sum_{u \in B(z,k)} \alpha(t,u)\beta(t,u) \\ \end{aligned}
∂ukt∂L(z,x)=k′∑∂yk′t∂L(z,x)∂ukt∂yk′t=−k′∑p(z∣x)1(u∈B(z,k′)∑yk′tα(t,u)β(t,u))yk′t(δkk′−ykt)=−p(z∣x)1k′∑((u∈B(z,k′)∑yk′tα(t,u)β(t,u))yk′t(δkk′−ykt))=−p(z∣x)1(k′=k∑(u∈B(z,k′)∑yk′tα(t,u)β(t,u))yk′t(1−ykt)+k′!=k∑(u∈B(z,k′)∑yk′tα(t,u)β(t,u))yk′t(0−ykt))=−p(z∣x)1(k′=k∑(u∈B(z,k′)∑α(t,u)β(t,u))(1−ykt)−k′!=k∑(u∈B(z,k′)∑α(t,u)β(t,u))ykt)=−p(z∣x)1(k′=k∑u∈B(z,k′)∑α(t,u)β(t,u)−k′=k∑(u∈B(z,k′)∑α(t,u)β(t,u))ykt)−k′!=k∑(u∈B(z,k′)∑α(t,u)β(t,u))ykt)=−p(z∣x)1(u∈B(z,k)∑α(t,u)β(t,u)−k′∑(u∈B(z,k′)∑α(t,u)β(t,u))ykt)=p(z∣x)1(−u∈B(z,k)∑α(t,u)β(t,u)+yktu∈B(z,k′)∑α(t,u)β(t,u))=p(z∣x)1(−u∈B(z,k)∑α(t,u)β(t,u)+yktp(z∣x))=ykt−p(z∣x)1u∈B(z,k)∑α(t,u)β(t,u)
至此,CTC的反向传播推导完成。CTC的解码过程后续会给出,并且附带原生Python代码实现,近期推出!