4.OCR文本识别Connectionist Temporal Classification(CTC)算法


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


1.基础介绍

论文:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

这是2006年第23次ICML会以上的一篇论文。

很多实际应用需要从未切分的数据中输出序列信息,如语音识别中的语音转文字,光学字符识别(Optical character recognition,OCR)中的字符图片转字符序列。循环神经网络(Recurrent neural networks,RNN)十分适合序列数据的学习,但其训练数据要求必须是切分后的序列,而实际应用中切分的训练序列数据标注比较困难,是很难获取的。

在这里插入图片描述

上图是OCR的两种模型,一种如图(a)可直接输入OCR检测得到的图片得到图片中的字符串can,另外一种需要先将图片按字符进行切割,这种方式比较数据处理比较复杂,而这种正是循环神经网络RNN要求的输入。

为了充分利用循环神经网络RNN处理序列数据的能力,同时避免对输入序列图像进行切分,本文作者提出了Connectionist Temporal Classificatio(CTC)算法

2.Connectionist Temporal Classification(CTC)算法

2.1 什么是Temporal Classification

S S S是从分布 D X × Z \mathcal{D}_{\mathcal{X}\times\mathcal{Z}} DX×Z从获取的训练数据,
输入空间 X = ( R m ) ∗ \mathcal{X}=(\mathbb{R}^m)^* X=(Rm) m m m维的实值向量序列,目标空间 Z = L ∗ \mathcal{Z}=L^* Z=L由字母集 L L L组成的标签序列,训练数据集 S S S中的每个样本由序列对 ( x , z ) (\mathbf{x},\mathbf{z}) (x,z)组成。目标序列 z = ( z 1 , z 2 , . . . , z U ) \mathbf{z}=(z_1,z_2,...,z_U) z=(z1,z2,...,zU)长度小于等于输入序列 x = ( x 1 , x 2 , . . . , x T ) , i . e . U ≤ T \mathbf{x}=(x_1,x_2,...,x_T),i.e.U\le T x=(x1,x2,...,xT),i.e.UT。输入序列和输出序列长度一般不同,因此没有先验知识可以对齐他们。

Temporal Classification的任务是使用训练数据 S S S,学习一个分类器,能够将输入序列分成对应的目标序列 h : X → Z h:\mathcal{X}\rightarrow\mathcal{Z} h:XZ

从第一部分介绍,可以知道OCR任务本身就是一个Temporal Classification,翻译成了时间序列分类问题。其输入是卷积后得到特征图序列,输出的是字符序列。

之所以被称为Connectionist Temporal Classification,是这样理解的,原始输入的是一整张联结在一起未切分的字符图像,输出的是字符序列,因为没有对原始图像上的字符进行切分预处理,因此被称之为连接序列分类。

2.2 CTC问题描述

在这里插入图片描述

从网络输入到获取标签序列要分成两步:

第一步,可以将输入为长度为 T T T的序列 x = [ x 1 , x 2 , . . . , x T ] \mathbf{x}=[x_1,x_2,...,x_T] x=[x1,x2,...,xT](序列中每个 x x x都是m维),输出为长度 T T T的序列 y = [ y 1 , y 2 , . . . , y T ] \mathbf{y}=[y_1,y_2,...,y_T] y=[y1,y2,...,yT](序列中每个 y y y都是n维),参数为 w w w的映射(即循环神经网络)定以为 N w : ( R m ) T → ( R n ) T \mathcal{N}_w:(\mathbb{R}^m)^T\rightarrow(\mathbb{R}^n)^T Nw:(Rm)T(Rn)T y = N w ( x ) \mathbf{y}=\mathcal{N}_w(\mathbf{x}) y=Nw(x)。将 y k t y_k^t ykt表示成第 t t t个序列值为 k k k的概率, L ′ T L'^T LT表示长度为 T T T的序列,其中每个元素取自字母集 L ′ = L ∪ { b l a n k } L'=L\cup\{blank\} L=L{blank},序列 L ′ T L'^T LT也被称之为路径,表示成 π \pi π

根据以上定义给定输入 x \mathbf{x} x,输出为路径 π \pi π的概率可表示成:

p ( π ∣ x ) = ∏ t = 1 T y π t t , ∀ π ∈ L ′ T p(\pi|\mathbf{x})=\prod_{t=1}^{T}y_{\pi_t}^t,\forall\pi\in L'^T p(πx)=t=1Tyπtt,πLT

其实,这里还有个条件,就是每一步输出之间是相互独立,上面的公式才能成立。

第二步,我们知道输入 x \mathbf{x} x对应的标签序列为长度等于 U U U的序列 z = [ z 1 , z 2 , . . . , z U ] , U ≤ T \mathbf{z}=[z_1,z_2,...,z_U],U\le T z=[z1,z2,...,zU]UT,在第一步中循环神经网络给出的只是长度为 T T T的中间序列 y \mathbf{y} y,要和长度为 U U U的标签序列 z \mathbf{z} z对应,还需要定义个从中间序列到标签序列的映射 B : L ′ T ↦ L < T \mathcal{B}:L'^T\mapsto L^{\lt T} B:LTL<T,很明显, B \mathcal{B} B是一个多对一的映射。这个映射可以定义为移除中间序列中的重复相邻字符和空格占位符,如 B ( s − t a − t t − e ) = B ( s − t − a a − t t − e ) = s t a t e \mathcal{B}(s-ta-tt-e)=\mathcal{B}(s-t-aa-tt-e)=state B(statte)=B(staatte)=state,定义了映射 B \mathcal{B} B后,可以将输出标签序列 z \mathbf{z} z的后验概率表示成:

p ( z ∣ x ) = ∑ π ∈ B − 1 ( z ) p ( π ∣ x ) p(\mathbf{z}|\mathbf{x})=\sum_{\pi\in\mathcal{B}^{-1}(\mathbf{z})}p(\pi|\mathbf{x}) p(zx)=πB1(z)p(πx)

2.2关于对齐

为什么要使用上述的方法来进行网络的训练呢?那是因为输入 x = [ x 1 , x 2 , . . . , x m ] \mathbf{x}=[x_1,x_2,...,x_m] x=[x1,x2,...,xm]和标签序列 z = [ z 1 , z 2 , . . . , z U ] \mathbf{z}=[z_1,z_2,...,z_U] z=[z1,z2,...,zU]之间在序列长度,序列长度比例,对应元素之间找不到什么对应关系。

在这里插入图片描述

如上图是对齐后的数据,但在实际中是很难知道 ( x 1 , x 2 ) ↦ c , ( x 3 , x 4 , x 5 ) ↦ a , ( x 6 ) ↦ t (x_1,x_2)\mapsto c,(x_3,x_4,x_5)\mapsto a,(x_6)\mapsto t (x1,x2)c,(x3,x4,x5)a,(x6)t,标注这样的数据也需要花费大量的时间,因此更希望模型能够拥有从未对齐数据中学习的能力,通过前面的介绍,使用CTC算法可以从未对齐的输入中求得标签序列。

2.3 前向后向算法

在这里插入图片描述

使用暴力方法计算

p ( z ∣ x ) = ∑ π ∈ B − 1 ( z ) p ( π ∣ x ) p(\mathbf{z}|\mathbf{x})=\sum_{\pi\in\mathcal{B}^{-1}(\mathbf{z})}p(\pi|\mathbf{x}) p(zx)=πB1(z)p(πx)

因为要计算每一条路径,因此对于序列字典中有 n n n个元素,长度为 T T T的序列,要计算所有路径的概率,时间复杂度为 O ( n T ) O(n^T) O(nT),这是指数级的时间复杂度,对于大部分长度的序列这个运算都过于耗时。论文作者为了解决这个问题,提出了前向后向递推算法,采用动态规划的方法将时间复杂度降到了 O ( n T ) O(nT) O(nT),使算法更可行。

先借个例子来看一下。

假设标签序列为
z = s t a t e \mathbf{z} = state z=state

在序列前后和每个字符中间添加空格占位符 − -

z ′ = − s − t − a − t − e − \mathbf{z}'=-s-t-a-t-e- z=state

z ′ \mathbf{z}' z中任意的字符重复任意次,经过 B \mathcal{B} B映射都能得到标签序列 s t a t e state state,因此可以将 z ′ \mathbf{z}' z当成满足变换条件的基础序列。 B \mathcal{B} B是多对一的映射,如下4个路径都能得到 s t a t e state state

B ( − − s t t a a − t e e − ) = s t a t e B ( − − s t t a − t − − − e ) = s t a t e B ( s s t − a a a − t e e − ) = s t a t e B ( s s t − a a − t − − − e ) = s t a t e \mathcal{B}(--sttaa-tee-)=state\\ \mathcal{B}(--stta-t---e)=state\\ \mathcal{B}(sst-aaa-tee-)=state\\ \mathcal{B}(sst-aa-t---e)=state B(sttaatee)=stateB(sttate)=stateB(sstaaatee)=stateB(sstaate)=state

z ′ \mathbf{z}' z写成列的形式,则上述四条路径可以写成如下图的形式:

在这里插入图片描述

从上图可以看到,四条路径在序列 t = 6 t=6 t=6时都经过字符 a a a,记上面的四条路径为 π 1 , π 2 , π 3 , π 4 \pi^1,\pi^2,\pi^3,\pi^4 π1,π2,π3,π4

π 1 = b = b 1 : 5 + { a } 6 + b 7 : 12 π 2 = r = r 1 : 5 + { a } 6 + r 7 : 12 π 3 = b 1 : 5 + { a } 6 + r 7 : 12 π 4 = r 1 : 5 + { a } 6 + b 7 : 12 \pi^1=b=b_{1:5}+\{a\}_6+b_{7:12}\\ \pi^2=r=r_{1:5}+\{a\}_6+r_{7:12}\\ \pi^3=b_{1:5}+\{a\}_6+r_{7:12}\\ \pi^4=r_{1:5}+\{a\}_6+b_{7:12} π1=b=b1:5+{a}6+b7:12π2=r=r1:5+{a}6+r7:12π3=b1:5+{a}6+r7:12π4=r1:5+{a}6+b7:12

y k t y_k^t ykt表示序列第 t t t步元素为 k k k的概率,则上面四条路径都包含 y a 6 y_a^6 ya6这一项,将计算上面四条路径的概率表示可以提取公因式写成:

f o w a r d = p ( b 1 : 5 + r 1 : 5 ∣ x ) = y − 1 ∗ y − 2 ∗ y s 3 ∗ y t 4 ∗ y t 5 + y s 1 ∗ y s 2 ∗ y t 3 ∗ y − 4 ∗ y a 5 b a c k w a r d = p ( b 7 : 12 + r 7 : 12 ∣ x ) = y − 7 ∗ y t 8 ∗ y − 9 ∗ y − 10 ∗ y − 11 ∗ y e 12 + y a 7 ∗ y − 8 ∗ y t 9 ∗ y e 10 ∗ y e 11 ∗ y − 12 foward = p(b_{1:5}+r_{1:5}|\mathbf{x}) = y_-^1*y_-^2*y_s^3*y_t^4*y_t^5 + y_s^1*y_s^2*y_t^3*y_-^4*y_a^5\\ backward = p(b_{7:12}+r_{7:12}|\mathbf{x}) = y_-^7*y_t^8*y_-^9*y_-^{10}*y_-^{11}*y_e^{12} + y_a^7*y_-^8*y_t^9*y_e^{10}*y_e^{11}*y_-^{12} foward=p(b1:5+r1:5x)=y1y2ys3yt4yt5+ys1ys2yt3y4ya5backward=p(b7:12+r7:12x)=y7yt8y9y10y11ye12+ya7y8yt9ye10ye11y12

然后上面四条路径的概率和可以写成:

p ( π 1 , π 2 , π 3 , π 4 ∣ x ) = f o r w a r d ∗ y a 6 ∗ b a c k w a r d p(\pi^1,\pi^2,\pi^3,\pi^4|\mathbf{x}) = forward*y_a^6*backward p(π1,π2,π3,π4x)=forwardya6backward

上面的介绍中只取了四条经过变换 B \mathcal{B} B后能得到 s t a t e state state的路径,实际上的路径要远远多于此:

在这里插入图片描述

从上图中选出经过 { a } 6 \{a\}_6 {a}6的所有路径,概率 ∑ B ( π ) = z , π 6 = a p ( π ∣ x ) \sum\limits_{\mathcal{B}(\pi)=\mathbf{z},\pi_6=a}p(\pi|x) B(π)=z,π6=ap(πx)( π 6 = a \pi_6=a π6=a表示路径 π \pi π的第6个字符为a),同样还是可以表示成如下形式:

∑ B ( π ) = z , π 6 = a p ( π ∣ x ) = f o r w a r d ∗ y a 6 ∗ b a c k w a r d \sum\limits_{\mathcal{B}(\pi)=\mathbf{z},\pi_6=a}p(\pi|x)=forward*y_a^6*backward B(π)=z,π6=ap(πx)=forwardya6backward

进一步推广,定义 α t ( s ) \alpha_t(s) αt(s)表示路径 π \pi π中的第t个字符与加了占位符后标签序列 z ′ \mathcal{z}' z的第s个字相对应且路径 π \pi π满足 B ( π 1 : t ) = z 1 : s \mathcal{B}(\pi_{1:t})=\mathbf{z}_{1:s} B(π1:t)=z1:s时所有路径 π 1 : t \pi_{1:t} π1:t的概率和,表示成:

α t ( s ) = ∑ B ( π 1 : t ) = 留 − z 1 : s ′ ∏ t ′ = 1 t y π t ′ t ′ \alpha_t(s)=\sum\limits_{\mathcal{B}(\pi_{1:t})\overset{留-}{=}\mathbf{z}'_{1:s}}\prod_{t'=1}^{t}y^{t'}_{\pi_{t'}} αt(s)=B(π1:t)=z1:st=1tyπtt

可以看到这等同于前向变量 f o r w a r d forward forward,现在来看 t = 1 t=1 t=1时的 α 1 ( s ) \alpha_1(s) α1(s),要经过 B \mathcal{B} B映射后能得到保留占位符的标签序列, s s s就只能等于1或者2,看上图中 − s − t − a − t − e − -s-t-a-t-e- state的例子,t=1时刻只能取 z ′ \mathcal{z}' z − - 或者 s s s,否则无法经过 B \mathcal{B} B映射得到标签序列,因此

α 1 ( 1 ) = y − 1 α 1 ( 2 ) = y z 2 ′ 1 α 1 ( s ) = 0 , ∀ s > 2 \alpha_1(1)=y^1_{-}\\ \alpha_1(2)=y^1_{\mathbf{z}'_2}\\ \alpha_1(s)=0,\forall s\gt2 α1(1)=y1α1(2)=yz21α1(s)=0,s>2

还看 s t a t e state state的例子,当过 z ′ 6 {\mathbf{z}'}_6 z6时, t = 5 t=5 t=5对应的字符只能是 t / − / a t/-/a t//a,可以推出来上面例子中

α 6 ( 6 ) = ( α 5 ( 4 ) + α 5 ( 5 ) + α 5 ( 6 ) ) ∗ y a 6 \alpha_6(6)=(\alpha_5(4)+\alpha_5(5)+\alpha_5(6))*y_a^6 α6(6)=(α5(4)+α5(5)+α5(6))ya6
一般化推广可得:

α t ( s ) = ( α t − 1 ( s − 2 ) + α t − 1 ( s − 1 ) + α t − 1 ( s ) ) ∗ y z s ′ t \alpha_t(s)=(\alpha_{t-1}(s-2)+\alpha_{t-1}(s-1)+\alpha_{t-1}(s))*y_{\mathbf{z}'_s}^{t} αt(s)=(αt1(s2)+αt1(s1)+αt1(s))yzst

还需考虑一个特殊情况,看下面例子 z = z o o , z ′ = − z − o − o − \mathbf{z}=zoo,\mathbf{z}'=-z-o-o- z=zoo,z=zoo,t=2,s=6或3:

在这里插入图片描述

很明显因为 B \mathcal{B} B映射会去除重复的字母,因此上面两种情况在 t − 1 t-1 t1时刻不能取 s − 2 s-2 s2

综上,可得最终 t ≥ 2 t\ge2 t2时前向递推公式为(也就是原论文上的递推公式):

α t ( s ) = { ( α t − 1 ( s − 1 ) + α t − 1 ( s ) ) ∗ y z s ′ t   i f   z s ′ = −   o r   z s ′ = z s − 2 ′ ( α t − 1 ( s − 2 ) + α t − 1 ( s − 1 ) + α t − 1 ( s ) ) ∗ y z s ′ t   o t h e r w i s e \alpha_t(s)=\left\{\begin{matrix} (\alpha_{t-1}(s-1)+\alpha_{t-1}(s))*y_{\mathbf{z}'_s}^{t}\,if\,z'_s=-\,or\,z'_s=z'_{s-2} \\ (\alpha_{t-1}(s-2)+\alpha_{t-1}(s-1)+\alpha_{t-1}(s))*y_{\mathbf{z}'_s}^{t}\,otherwise \end{matrix}\right. αt(s)={(αt1(s1)+αt1(s))yzstifzs=orzs=zs2(αt1(s2)+αt1(s1)+αt1(s))yzstotherwise

将公式中相同的项合并一下就可以得到论文上的公式了。

同样的方法可以定义 b a c k w a r d backward backward:

β t ( s ) = ∑ B ( π t : T ) = 留 − z s : ∣ z ′ ∣ ′ ∏ t ′ = t T y π t ′ t ′ \beta_t(s)=\sum\limits_{\mathcal{B}(\pi_{t:T})\overset{留-}{=}\mathbf{z}'_{s:|z'|}}\prod_{t'=t}^{T}y^{t'}_{\pi_{t'}} βt(s)=B(πt:T)=zs:zt=tTyπtt

t ≥ 2 t\ge2 t2 β t ( s ) \beta_t(s) βt(s)的递推公式:

β t ( s ) = { ( β t + 1 ( s ) + β t + 1 ( s + 1 ) ) ∗ y z s ′ t   i f   z s ′ = −   o r   z s ′ = z s + 2 ′ ( β t + 1 ( s ) + β t + 1 ( s + 1 ) + β t + 1 ( s + 2 ) ) ∗ y z s ′ t   o t h e r w i s e \beta_t(s)=\left\{\begin{matrix} (\beta_{t+1}(s)+\beta_{t+1}(s+1))*y_{\mathbf{z}'_s}^{t}\,if\,z'_s=-\,or\,z'_s=z'_{s+2} \\ (\beta_{t+1}(s)+\beta_{t+1}(s+1)+\beta_{t+1}(s+2))*y_{\mathbf{z}'_s}^{t}\,otherwise \end{matrix}\right. βt(s)={(βt+1(s)+βt+1(s+1))yzstifzs=orzs=zs+2(βt+1(s)+βt+1(s+1)+βt+1(s+2))yzstotherwise

求得 α t ( s ) \alpha_t(s) αt(s) β t ( s ) \beta_t(s) βt(s)后,标签序列 z \mathbf{z} z的后验概率可以写成,

p ( z ∣ x ) = ∑ z s ′ ∈ π t α t ( s ) β t ( s ) y z s ′ t p(\mathbf{z}|\mathbf{x})=\sum_{z'_s\in\pi_t}\frac{\alpha_t(s)\beta_t(s)}{y_{z'_s}^t} p(zx)=zsπtyzstαt(s)βt(s)

求得 p ( z ∣ x ) p(\mathbf{z}|\mathbf{x}) p(zx)后,可以知道使用 C T C CTC CTC时的目标就是最大化 p ( z ∣ x ) p(\mathbf{z}|\mathbf{x}) p(zx),可以定义损失函数为 − l o g ( p ( z ∣ x ) ) -log(p(\mathbf{z}|\mathbf{x})) log(p(zx)),可以推导损失的计算和损失函数梯度都能使用递推的方式来计算,减少运算量,加快运算速度。

2.4 推理时

训练完成后,在网络推理时希望取概率最大的输出序列:

z ∗ = a r g m a x z   p ( z ∣ x ) \mathbf{z}^* = \underset{\mathbf{z}}{argmax} \,p(\mathbf{z}|\mathbf{x}) z=zargmaxp(zx)

对所有路径的概率求和,然后取概率最大的路径作为预测的结果,应该是最合理的方式,但当序列比较长时面临计算量过大,影响推理速度的情况。

一种做法是对于第 t t t步,取概率最大的字符,然后将所有的字符组合起来经过去重当作最终的输出,但这种做法只考虑了一条路径,有可能有多条路径对应标签,各条路径的概率加和后有可能更大。

一种替代的折衷方法是改进版的Beam Search

常规的Beam Search算法,对于每个时间步取概率最大的几个(Beam Size)可能结果,如下为字母集为 − , a , b -,a,b ,a,bBeam Size=3Beam Search的过程:

在这里插入图片描述

上图中Beam Search到当前步最大的几个(Beam Size)可能字符都只有一条前缀序列,实际上可以有多条前缀序列和当前的字符组合后都得到相同的输出,如下图对于路径长度 T = 2 T=2 T=2 λ a \lambda a λa, a − a- a, a a aa aa最后都能对应的 a a a

a a a b a a a b b a ϵ a b ϵ a b ϵ a b λ a b ϵ a b ϵ a b ϵ a b λ ϵ a b T = 4 T = 3 T = 2 T = 1 current hypotheses proposed extensions current hypotheses proposed extensions current hypotheses proposed extensions current hypotheses Multiple extensions merge to the same prefix empty string

且观察 T = 3 T=3 T=3时,前缀序列 a a aa aa对应的输出有可能是 a a a或者 a a aa aa,因此对应的概率应该分别进行计算。

3.pytorch中的CTCLOSS

计算未切分的连续时间序列和目标序列之间的损失。

torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

class CTCLoss:
     ...
     def forward(self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) -> Tensor:
          ...
     
  • log_probs:Tensor of size (T,N,C)/(T,C),T是输入长度,N是Batch Size,C是序列字典的大小(包括空格)

  • targets:Tensor of size(N,S)Nbatch sizeS是最大目标序列长度,目标序列中的每个元素是类别的序号。

  • input_lengths,每个输入序列的长度,为元组tupleshape(N,)的张量,Nbatch sizeinput_lengths的值 ≤ T \le T T

  • target_lengths,每个目标序列的长度,为元组tupleshape(N,)的张量,Nbatch size,如果targetsshape(N,S),这里其实是把每个 t a r g e t target target添加padding后变成了S,假设第n个序列目标长度为 s n s_n sn,target_lengths中第n个元素值就为 s n s_n sn

import torch

T = 2
C = 3
N = 1
S = 2
S_min = 1

input = torch.randn(T,N,C).log_softmax(2).detach().requires_grad_()
print(input)
target = torch.tensor([0,1], dtype=torch.long).reshape(shape=(N, S))
print(target)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.tensor([2], dtype=torch.long).reshape(shape=(N,)) 
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)

# tensor([[[-0.4002, -1.5314, -2.1752]],         [[-0.8444, -2.2039, -0.7770]]], requires_grad=True)
# tensor([[0, 1]])
# tensor(1.3021, grad_fn=<MeanBackward0>)

上面示例的计算过程:

在这里插入图片描述

从上图可以看到目标是 01 01 01at路径有且仅有此一条,损失值计算为:

l o s s = − 1 2 [ − 0.4002 + ( − 2.2039 ) ] = 1.3021 loss = -\frac{1}{2}[-0.4002+(-2.2039)]=1.3021 loss=21[0.4002+(2.2039)]=1.3021


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


参考资料

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值