SAM论文翻译

论文地址:https://arxiv.org/pdf/2308.10493v1.pdf
代码地址:https://github.com/liuzhuang1024/SAM

Abstract

目前的方法不能明确地学习不同符号之间的相互作用,当面对相似的符号时可能会识别错误。为了缓解这个问题,我们提出了一种简单而有效的方法来增强语义交互学习(SIL)。具体来说,我们首先构造了一个基于统计符号共现概率的语义图。然后设计了一个语义感知模块(SAM),它将视觉和分类特征投射到语义空间中。不同投影向量之间的余弦距离表示符号之间的相关性。联合优化HMER和SIL可以提高模型对符号关系的理解。此外,SAM可以很容易地插入到现有的基于注意力的HMER模型中。

在这里插入图片描述

1、Introduction

本文的主要贡献简要总结如下:

  • 据我们所知,我们是第一个使用共现来表示数学表达式中符号之间的关系,并验证增强语义表示学习的有效性。
  • 我们提出了一种语义感知方法,联合优化符号关系学习和HMER,可以持续提高HMER模型的性能。
  • 我们提出的语义感知模块可以很容易地插入到基于注意力的HMER模型中,并且在推理阶段无需额外的计算。

具体来说,我们采用DWAP 作为基线网络。在SAM的帮助下,SAM-DWAP在2014年、2016年和2019年的CROHME平台上的表现分别超过了DWAP的2.2%、2.8%和4.2%。此外,采用最新的SOTA方法CAN作为基线网络,我们的方法获得了新的SOTA结果(CROHME 2014为58.0%,CROHME 2016为56.7%,CROHME 2019为58.0%)。这表明,我们的方法可以推广到现有的各种HMER编解码器模型中,并提高其性能。

2、Related Work

3、Methodology

我们的方法的总体框架如图所示。该管道包括以下几个部分:采用DenseNet作为编码器来提取特征。输入DenseNet一个大小为H×W×1的灰度图像X,其中H和W分别为图像高度和图像宽度,并返回一个2D特征图 F ∈ R H ′ × W ′ × 684 F \in \R^{H^{'} \times W ^{'} \times 684} FRH×W×684,其中H/H‘=W/W’=16。解码器使用特征映射,并逐步预测Latex标记。语义感知模块(SAM)由两个结构相似的分支(视觉分支和分类分支)组成,它们分别采用了视觉特征和分类特征。将视觉特征和分类特征分别投影到语义空间中,得到投影的视觉向量和分类向量。计算不同时间步长的投影向量之间的余弦距离表示它们之间的相关性。

在这里插入图片描述

3.1、Semantic Graph

捕获全局上下文信息已被证明是提高识别的鲁棒性的有效方法。然而,与单词相比,在数学表达式中使用符号相对比较随意。如何在数学表达式中表达不同符号之间的关系是一个有待解决的问题。我们的直觉是,共现图中的值的大小反映了不同符号之间的关系,就像文本中不同的字符有不同的排列一样。

语义图被定义为G =(S,E),其中S = {s1,s2,…,sN }表示符号节点的集合,E表示边,表示任意两个符号之间的依赖性。图G的相关矩阵 R = { r i , j } i , j = 1 N R=\{r_{i,j}\}^{N}_{i,j=1} R={ri,j}i,j=1N包含与每条边相关的非负权值。相关矩阵是一个条件概率矩阵, r i j r_{ij} rij设为 P ( s i / s j ) P(si/sj) P(si/sj),其中P通过训练集计算。然而,R是一个非对称矩阵,即rij = rji。为了便于计算,我们将非对称矩阵转化为一个对称矩阵如下:
R ′ = 1 2 ( R + R T ) R^{'}=\frac{1}{2}(R+R_{T}) R=21(R+RT)

3.2、Semantic Aware Module

如图3所示,语义感知模块(SAM)包含两个分支,即视觉特征分支和分类特征分支。每个分支包括两个“LBR”块,然后是一个线性层。通过叠加线性层、批量归一化和ReLU激活来建立一个“LBR”块。我们使用SAM模块将视觉向量 v v i s v_{vis} vvis和分类向量 v c l s v_{cls} vcls投影语义空间中,得到映射后的视觉向量 v v i s ′ v^{'}_{vis} vvis和分类向量 v c l s ′ v^{'}_{cls} vcls,计算公式如下:
v v i s ′ = W 3 v i s ( σ ( ϵ ( W 2 v i s ( σ ( ϵ ( W 1 v i s v v i s + b 1 v i s ) ) ) ) + b 2 v i s ) ) ) + b 3 v i s v c l s ′ = W 3 c l s ( σ ( ϵ ( W 2 c l s ( σ ( ϵ ( W 1 c l s v c l s + b 1 c l s ) ) ) ) + b 2 c l s ) ) ) + b 3 c l s v^{'}_{vis} =W^{vis}_{3}( \sigma (\epsilon (W^{vis}_2(\sigma (\epsilon (W^{vis}_1v_{vis}+b^{vis}_1))))+b^{vis}_2)))+b^{vis}_3 \\ v^{'}_{cls} =W^{cls}_{3}( \sigma (\epsilon (W^{cls}_2(\sigma (\epsilon (W^{cls}_1v_{cls}+b^{cls}_1))))+b^{cls}_2)))+b^{cls}_3 vvis=W3vis(σ(ϵ(W2vis(σ(ϵ(W1visvvis+b1vis))))+b2vis)))+b3visvcls=W3cls(σ(ϵ(W2cls(σ(ϵ(W1clsvcls+b1cls))))+b2cls)))+b3cls
在这里插入图片描述

我们的目标是优化投影的视觉向量 v v i s ′ v^{'}_{vis} vvis和投影的分类向量 v c l s v_{cls} vcls。通过计算 c o s ( v i ′ , v j ′ ) cos(v^{'}_{i},v^{'}_{j}) cos(vivj)使所有i,j都接近相关矩阵 R i j R_{ij} Rij,其中 c o s ( v i ′ , v j ′ ) cos(v^{'}_{i},v^{'}_{j}) cos(vivj)表示 v i ′ v^{'}_{i} vi v j ′ v^{'}_{j} vj之间的余弦相似性:
c o s ( v i ′ , v i ′ ) = v ′ T v j ∥ v ′ T ∥ ∥ v j ∥ cos(v^{'}_{i},v^{'}_{i}) = \frac{{v^{'}}^{T}v_{j}}{\left \| {v^{'}}^{T}\right \| \left \|v_{j}\right \|} cos(vi,vi)= vT vjvTvj

3.3、Decoder

上图3为解码器的结构。该解码器主要包含两个门控循环单元(GRU)单元和一个注意模块。第一个GRU以符号嵌入 E ( y t − 1 ) E(y_{t-1}) E(yt1)和最后一步预测的历史状态 h t − 1 h_{t-1} ht1作为输入,输出一个新的隐藏状态向量 h t ′ h^{'}_{t} ht
h t ′ = G R U ( E ( y t − 1 ) , h t − 1 ) h^{'}_{t} =GRU(E(y_{t-1}),h_{t-1}) ht=GRU(E(yt1),ht1)
然后,注意模块通过其注意机制计算出注意权重 α t \alpha _{t} αt
e t = W ω ( t a n h ( W h ′ h t ′ + W f F + W α a t t t ) ) α t = e x p ( e t ) / ∑ e x p ( e t ) e_{t} = W_{\omega}(tanh(W^{'}_{h}h^{'}_{t}+W_{f}F+W_{\alpha}att_{t})) \\ \alpha_{t}=exp(e_{t})/\sum exp(e_{t}) et=Wω(tanh(Whht+WfF+Wαattt))αt=exp(et)/exp(et)
F表示特征图,attt表示覆盖注意,等于过去所有注意概率之和:
a t t t = ∑ i α i , i ∈ [ o , t − 1 ] att_{t}=\sum_{i}\alpha_{i},i\in[o,t-1] attt=iαi,i[o,t1]
α t \alpha _{t} αt和F相乘得到视觉特征向量 v v i s v_{vis} vvis
v v i s = α t ⨂ F v_{vis}=\alpha_{t}\bigotimes F vvis=αtF
第二个GRU将 v v i s v_{vis} vvis h t ′ h^{'}_{t} ht作为输入,返回隐藏状态 h t h_{t} ht
KaTeX parse error: Can't use function '$' in math mode at position 20: …t}=GRU(v_{vis},$̲h^{'}_{t})
然后将 E ( y t − 1 ) E(y_{t-1}) E(yt1) v v i s v_{vis} vvis h t h_{t} ht聚合,得到分类特征向量和符号概率:
v c l s = W e E ( y t − 1 ) + W h h t + W v V v i s p s y m b o l = s o f t m a x ( W s v c l s ) v_{cls}=W_{e}E(y_{t-1})+W_{h}h_{t}+W_{v}V_{vis} \\ p_{symbol} = softmax(W_{s}v_{cls}) vcls=WeE(yt1)+Whht+WvVvispsymbol=softmax(Wsvcls)

3.4、Loss Function

整体函数由三个部分组成,定义如下:
L = L s y m b o l + L v i s + L c l s \mathcal{L} = \mathcal{L}_{symbol} + \mathcal{L}_{vis} + \mathcal{L}_{cls} L=Lsymbol+Lvis+Lcls
其中$ \mathcal{L}{symbol} 符号是预测与标签的交叉熵分类损失。 符号是预测与标签的交叉熵分类损失。 符号是预测与标签的交叉熵分类损失。\mathcal{L}{vis} 和 和 \mathcal{L}_{cls}$为L2回归损失定义如下:
L v i s = ∑ i n ∑ j n ( c o s ( v v i s , i , v v i s , j ) − R i , j ) 2 L c l s = ∑ i n ∑ j n ( c o s ( v c l s , i , v c l s , j ) − R i , j ) 2 \mathcal{L}_{vis} =\sum^{n}_{i}\sum^{n}_{j}(cos(v_{vis,i},v_{vis,j})-R_{i,j})^{2} \\ \mathcal{L}_{cls}=\sum^{n}_{i}\sum^{n}_{j}(cos(v_{cls,i},v_{cls,j})-R_{i,j})^{2} Lvis=injn(cos(vvis,i,vvis,j)Ri,j)2Lcls=injn(cos(vcls,i,vcls,j)Ri,j)2

4、Experiments

4.1、Datasets

在这里插入图片描述

4.2、Implementation Details

使用一个32GB内V100进行实验。批处理大小设置为8。两个gru的隐藏状态大小和单词嵌入的维数均设为256。在训练过程中使用 Adadelta优化器,其中ρ设置为0.95,ϵ设置为10−6。学习速率从0开始,并在第一个阶段结束时单调地增加到1。之后,随着余弦学习时间表衰减到0。对于CROHME数据集,epoch设置为240,而对于HME100K数据集,epoch设置为40。

4.3、Evaluation Protocol

4.4、Comparison with State-of-the-Art

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值