论文链接: https://arxiv.org/pdf/2207.11463.pdf
代码地址:https://github.com/LBH1024/CAN
Abstract
我们为手写字识别设计了一个叫做CAN的网络,这个网络加入了两个优化任务:手写公式识别和符号计数。具体来说,我们设计了一个弱监督计数模型,这个模型不需要符号位置也能预测符号的个数,然后将其插入到编解码结构的手写公式识别模型中。在HMER的基准数据集上的实验验证表明,联合优化计数模型和手写公式模型,利于纠正编解码器模型的预测误差,并且CAN始终优于当前最先进的方法。特别是,与HMER的编解码器模型相比,所提出的计数模块所造成的额外时间成本是边际的。
1、Introduction
计数和HMER是两个互补的任务,使用计数可以提高HMER的性能。这种做法主要基于以下两方面的考虑:(1)符号计数能够提供字符级位置信息,这能使注意力更加准确;(2)计数结果可以表示符号的数量,可以作为额外的全局信息,来提高识别的准确性。
本文的主要贡献两个:(1)将符号计数引入HMER,并揭示了HMER与符号计数的相关性和互补性。2)我们提出了一种联合优化符号计数和HMER的新方法,提高了HMER的编解码器模型的性能。
2、Related Work
2.1、HMER
略
2.2、Object Counting
对象计数大致可以分为基于检测和基于回归两类。基于检测的方法通过检测每个实例来获得数字。基于回归的方法通过回归密度图来学习计数,预测的计数等于密度图的积分。为了提高计数精度,基于回归的方法广泛采用了多尺度策略、注意机制和视角信息。然而,基于检测和基于密度图回归的方法都需要对象位置注释,如方框和点的标注。为了减轻标记工作,一些方法提出只使用计数标注。他们发现可视化的特征图可以准确地反映物体的区域。与之前大多数专门针对类别的计数模块(例如,人群计数)不同,我们的计数模块是为多类对象计数而设计的,因为公式通常各种包含不同的符号。在OCR领域,Xie等提出了一个基于计数的损失函数,主要为场景文本(单词或文本行)设计,所以我们的模型也可以在特征水平和损失水平上利用更复杂文本的计数信息(如数学表达式)的信息。
3、 Methodology
3.1、Overview
图1 CAN模型架构
我们的计数感知网络主要包含:backbone、多尺度计数模块(MSCM)、结合计数的注意解码器 (CCAD)。和DWAP 一样,我们应用DenseNet作为backbone。对于一个输入灰度图像 H ′ × W ′ × 1 H^{'} \times W^{'} \times 1 H′×W′×1,首先使用DenseNet提取图像特征,输出维度为 H × W × 1 H \times W \times 1 H×W×1,其中 H ’ H = W ’ W = 16 \frac{H^{’}}{H}=\frac{W^{’}}{W}=16 HH’=WW’=16。MSCM和CCAD同时使用特征图,其中MSCM用于预测图像中符号的个数并生成一个一维的计数向量来表示计数的结果。将特征图F和计数向量输入CCAD,得到预测输出。
3.2、Multi-Scale Counting Module
图2 MSCM模型架构
MSCM由多尺度特征、通道注意力和全局平均池化层组成。由于不同的书写习惯,公式图像通常包含大小不同的符号,单一大小的卷积核不能处理这种变化,为了解决这个问题,我们首先提出了利用两个并行卷积分支,这两个并行卷积分支使用不同大小的卷积核(设置为3×3和5×5)来提取多尺度特征。在卷积层之后,采用通道注意来进一步增强特征信息。在这里,我们选择其中一个分支作为简单的说明。我们将H表示为从卷积(3×3或5×5)层中提取的特征图。增强的特征S可以写为:
Q
=
σ
(
W
1
(
G
(
H
)
)
+
b
1
)
S
=
Q
⊗
g
(
W
2
Q
+
b
2
)
Q=\sigma(W_{1}(G(H))+b_{1}) \\ S=Q \otimes g(W_{2}Q+b2)
Q=σ(W1(G(H))+b1)S=Q⊗g(W2Q+b2)
这里,G表示全局平均池化层,
σ
、
g
\sigma、g
σ、g表示RELU和sigmoid函数,
⊗
\otimes
⊗表示通道乘积,
W
1
、
W
2
、
W
3
W_{1}、W_{2}、W_{3}
W1、W2、W3表示可训练权重。
在得到增强的特征S后,我们使用1×1卷积将通道数从
c
o
c_{o}
co减少到C,其中C是符号类的数量。理想情况下,符号计数结果应主要从前景计算。因此,我们在1×1卷积之后使用sigmoid函数生成(0,1)范围内的计数特征图M。对于每个
M
i
M_{i}
Mi,它可以有效地反映第i个符号类的位置,如图1所示。从某种意义上说,每个
M
i
M_{i}
Mii实际上都是一个伪密度图,我们可以利用求和池化算子来获得计数向量
v
i
v_{i}
vi:
V
i
=
∑
p
=
1
H
∑
q
=
1
W
M
i
,
p
q
V_{i}=\sum^{H}_{p=1}\sum^{W}_{q=1}M_{i,pq}
Vi=p=1∑Hq=1∑WMi,pq
这里,
v
i
v_{i}
vi表示第i个字符的预测计数,值得注意的是,不同分支的特征图包含不同的尺度信息,且具有高度的互补性。因此,我们结合互补计数向量,并使用平均算子生成最终结果
V
f
V_{f}
Vf,然后将其输入解码器CCAD中。
3.3、Counting-Combined Attentional Decoder
图3 CCAD模型架构
给定二维特征图 F ∈ H × W × 684 F \in H×W×684 F∈H×W×684,我们首先使用1×1卷积来改变通道数,得到变换特征 T ∈ H × W × 512 T \in H×W×512 T∈H×W×512。然后,为了提高模型对空间位置的感知,我们使用固定的绝对位置编码 P ∈ H × W × 512 P\in H×W×512 P∈H×W×512来表示T中的不同空间位置。
在t时刻的解码阶段,我们将
y
t
−
1
y_{t-1}
yt−1的embeding输入到GRU中得到隐藏状态
h
t
∈
1
×
256
h_{t} \in 1 \times 256
ht∈1×256,将他和特征T、空间位置P结合,我们可以得到注意力权重
α
t
\alpha_{t}
αt:
e
t
=
w
T
t
a
n
h
(
T
+
P
+
W
a
A
+
W
h
h
t
)
α
t
,
i
j
=
e
x
p
(
e
t
,
i
j
)
/
∑
p
=
1
H
∑
q
=
1
W
e
t
,
p
q
e_{t} = w^{T}tanh(T+P+W_{a}A+W_{h}h_{t}) \\ \alpha_{t,ij}=exp(e_{t,ij})/\sum^{H}_{p=1}\sum^{W}_{q=1}e_{t,pq}
et=wTtanh(T+P+WaA+Whht)αt,ij=exp(et,ij)/p=1∑Hq=1∑Wet,pq
其中,A表示所有过去的注意力权重。
将空间权重乘积应用于注意权重
α
t
\alpha_{t}
αt和特征F,我们可以得到上下文向量
C
∈
1
×
256
C \in 1×256
C∈1×256。在之前的大多数HMER方法中,他们只使用上下文向量
C
C
C、隐藏状态
h
t
h_{t}
ht和
y
t
−
1
y_{t-1}
yt−1的embeding来预测
y
t
y_{t}
yt。实际上,C只是对应于特征图f的一个局部区域,我们认为ht和embeding也缺乏全局信息。考虑到计数向量
v
v
v是从全局计数的角度计算出来的,它可以作为额外的全局信息,使预测更加准确,我们将它们结合在一起来预测
y
t
y_{t}
yt如下:
p
(
y
t
)
=
s
o
f
t
m
a
x
(
W
o
T
(
W
c
C
+
W
v
v
+
W
t
h
t
+
W
e
E
)
+
b
o
)
y
t
∼
p
(
y
t
)
p(y_{t})=softmax(W^{T}_{o}(W_{c}C+W_{v}v+W_{t}h_{t}+W_{e}E)+b_{o}) \\ y_{t} \sim p(y_{t})
p(yt)=softmax(WoT(WcC+Wvv+Wtht+WeE)+bo)yt∼p(yt)
3.4、Loss Function
整体损失函数由两部分组成,定义如下:
L
=
L
c
l
s
+
L
c
o
u
n
t
i
n
g
L = L_{cls} + L_{counting}
L=Lcls+Lcounting
其中,
L
c
l
s
L_{cls}
Lcls是预测概率
p
(
y
t
)
p(y_{t})
p(yt)分类损失,常使用交叉熵分类损失函数。
L
c
o
u
n
t
i
n
g
L_{counting}
Lcounting表示符号的计数损失,采用
s
m
o
o
t
h
L
1
smooth L1
smoothL1回归损失。
4、Experiments
4.1、Datasets
略
4.2、Implementation Details
- Nvidia Tesla V100 32GB
- batch size = 8
- Adadelta优化器
- 学习率先从0增加到1然后使用cosine schedules学习策略
- epoch=240
- 忽略六类符号标签:
sos
、eos
、{
、}
、^
、_
的计数值,将其设置为0(论文认为这些符号为隐藏符号,会影响准确率)
4.3 Evaluation Metrics
CAN模型的结果比较:
表1 CROHME数据集上的结果
4.4 Comparison with State-of-the-Art
略
4.5 Results on the HME100K Dataset
表2 HME100K数据集上的结果