VLPT-STD
论文标题:Vision-Language Pre-Training for Boosting Scene Text Detectors
github:https://github.com/AlibabaResearch/AdvancedLiterateMachinery/tree/main/OCR/VLPT-STD
代码实验配置:https://blog.csdn.net/qq_44445108/article/details/133926242
摘要
利用视觉-语言联合表示学习的方法,涉及两种模态之间的跨模态交互。
提出了一个包含图像编码器、文本编码器和跨模态编码器的架构,三个前置任务:图像-文本对比学习(CLIP)ITC、掩码语言建模MIM、单词-图像内的预测WIP
引言
三个要解决的问题
为了加速训练过程,增强泛化能力,预训练技术也被大量应用到场景文本检测方法中。早期的方法都是在ImageNet上进行预训练,然后在要预测的数据集上进行微调。然而,来自ImageNet的自然图像和场景文本图像之间存在明显的域差距,这可能导致微调后的性能增益有限。
在SynthText合成数据集上预训练的文本检测模型优于在ImageNet上预训练的模型,但是仍然存在合成文本数据和真实文本数据之间的域差距,通常会导致文本纹理被误检。
STKM通过对图像编码,挖掘文本知识进行预训练,可以通过编码后的图像解码出文本。但是,STKM采用字符级解码过程,难以有效利用词库中的上下文信息,而且STKM本质上还是单向映射,从视觉模态到语言模态,并没有利用跨模态的交互学习信息表示。
作者的工作
提出了VLPT-STD,首先通过对比学习对齐图像和文本的单模态表示,然后通过掩码语言模型和单词-图像预测的预训练任务关注细粒度的文本区域。
主要流程就是分别从图像、文本编码器中提取图像、文本嵌入,然后通过各种预训练任务将两种模态的嵌入输入到联合编码器中进行细粒度的跨模态交互。
不同方法之间的区别
(a) 常规的方法仅提取图像特征,在SynthText上预训练,并在要推理的数据集上微调;
(b) STKM方法还是仅提取图像特征,和常规方法不同的是,预训练的输出不是文本框,而是将图像特征解码为文本,相当于换了一种预训练的方法;
© 作者提出的方法是通过图像特征和文本特征的交互(三个前置任务),增强图像编码器的主干。
其实这三类方法本质上都是一样的,即预训练的目的都是增强图像编码器的主干,增强提取文本图像特征的能力。
方法
模型架构
image encoder
包括ResNet、FPN和注意力池化层。
图像(512×512)输入到ResNet中,得到各个残差网络块的特征,记为C2、C3、C4、C5。先通过一个1×1的卷积层将这四个特征图进行变换,通道数都设置为256,P5为C5变换后的特征图,P2、P3、P4为通过FPN上采样变换特征图相加得到。将P2、P3、P4、P5特征图横向连接起来,再通过一个步长为2的1×1卷积层(proj),通道数从1024减少到384,特征图大小减小为原来的1/16。最后,通过注意力池化层来提取以图像全局平均池化表示为条件的视觉嵌入。
视觉嵌入的表示:
V
=
{
V
[
C
L
S
]
,
V
1
,
…
,
V
S
}
∈
R
d
\mathbf{V}=\{V_{[\mathrm{CLS}]},V_{1},…,V_{S}\}\in\mathbb{R}_{d}
V={V[CLS],V1,…,VS}∈Rd
V[CLS]表示Transformer模型中常用的特殊标记之一,这里代码图像序列的开头。
text encoder
包括三个多头自注意力模块
对于一个文本,先将其拆分为单词序列,然后通过WordPiece(BERT使用的分词法)将每个单词tokenize为tokens,再通过嵌入矩阵将tokens转换为文本嵌入。
代码中就是使用的BERT预训练好的分词器:
from transformers import (
DataCollatorForLanguageModeling,
DataCollatorForWholeWordMask,
BertTokenizer,
)
BertTokenizer.from_pretrained('bert-base-uncased')
文本嵌入:
W
=
{
W
[
C
L
S
]
,
W
1
,
W
2
,
⋯
,
W
K
}
∈
R
d
W=\{W_{[\mathrm{CLS}]},W_{1},W_{2},\cdots,W_{K}\}\in\mathbb{R}_{d}
W={W[CLS],W1,W2,⋯,WK}∈Rd
在这个文本嵌入的基础上,增加一个可训练的位置嵌入。
cross-modal encoder
包括四个相同的Transformer解码器层,每个解码器层由多头自注意力模块、多头交叉注意力模块和前馈网络组成。
在每个注意力模块前使用层归一化LN,在每个注意力模块之后使用残差连接;前馈网络包含两个MLP,采用GELU激活函数。
预训练任务
单模态编码器上的图文对比学习——ITC、图像中单词预测——WIP
跨模态编码器上的掩码语言建模——MLM
ITC
和CLIP的思想一模一样,旨在给定一个文本嵌入,找到一个最契合的图像嵌入;给定一个图像嵌入,找到一个最契合的文本嵌入,最大化成对图像-文本的余弦相似度,典型的对比学习思想。
和CLIP不同的是,CLIP的文本是描述图像的内容,ITC的文本是图像中包含文本的序列。
将SynthText数据集中的每个图文对表示为:
(
x
i
I
,
x
i
T
)
(x_{i}^{I},x_{i}^{T})
(xiI,xiT)
前一表示图像,后一表示文本。令I和T分别表示图像嵌入和文本嵌入,则一个批次内一个图像和所有文本的对比损失、一个文本和所有图像的对比损失分别为:
L
l
2
T
=
−
∑
j
log
exp
(
I
j
⋅
T
j
/
τ
)
∑
k
=
1
N
exp
(
I
j
⋅
T
k
/
τ
)
\mathcal{L}_{\mathrm{l2T}}=-\sum_{j}\log\frac{\exp\left(I_{j}\cdot T_{j}/\tau\right)}{\sum_{k=1}^{N}\exp\left(I_{j}\cdot T_{k}/\tau\right)}
Ll2T=−j∑log∑k=1Nexp(Ij⋅Tk/τ)exp(Ij⋅Tj/τ)
L T21 = − ∑ j log exp ( T j ⋅ I j / τ ) ∑ k = 1 N exp ( T j ⋅ I k / τ ) \mathcal{L}_{\text{T21}}=-\sum_{j}\log\frac{\exp{(T_{j}\cdot I_{j}/\tau)}}{\sum_{k=1}^{N}\exp{(T_{j}\cdot I_{k}/\tau)}} LT21=−j∑log∑k=1Nexp(Tj⋅Ik/τ)exp(Tj⋅Ij/τ)
则总的对比损失为:
L
I
T
C
=
λ
1
L
I
2
T
+
λ
2
L
T
2
I
\mathcal{L}_{\mathrm{ITC}}=\lambda_{1}\mathcal{L}_{\mathrm{I2T}}+\lambda_{2}\mathcal{L}_{\mathrm{T2I}}
LITC=λ1LI2T+λ2LT2I
当λ1=λ2=0.5时,效果最好。
WIP
利用图像嵌入和单词级文本嵌入之间的对比学习来区分图像中存在的单词(正样本)和图像中不存在的单词(负样本),预测单词在图像中的存在性。
在训练过程中基于文本嵌入的相似性对困难的负样本进行采样,从单词文本中捕获更细粒度的线索(基于OHEM思想)。比如正样本"lost",负样本"last"。
对于每个文本嵌入,采样其前L个最近邻的文本嵌入作为负例,通过对比学习降低正样本和负样本之间的相似度。
则每个图像-文本对 对比学习损失为:
L
W
I
P
=
−
∑
k
=
1
K
log
exp
(
I
⋅
W
k
/
τ
)
exp
(
I
⋅
W
k
/
τ
)
+
∑
l
=
1
L
exp
(
I
⋅
W
~
k
l
/
τ
)
\mathcal{L}_{\mathrm{WIP}}=-\sum_{k=1}^{K}\log\frac{\exp(I\cdot W_{k}/\tau)}{\exp(I\cdot W_{k}/\tau)+\sum_{l=1}^{L}\exp\left(I\cdot\widetilde{W}_{k}^{l}/\tau\right)}
LWIP=−k=1∑Klogexp(I⋅Wk/τ)+∑l=1Lexp(I⋅W
kl/τ)exp(I⋅Wk/τ)
MLM
MLM是一种基于Transformer模型的无监督学习方法,常用于预训练语言模型,如BERT。
根据BERT的做法,对输入的15%文本序列,替换为10%的随机字符、10%保持原样和80%[MASK]。对那些未[MASK]处理的文本序列和所有图像序列,通过最小化负对数似然来恢复那些mask的单词。
那么MLM的损失为:
L
M
L
M
=
−
E
(
W
,
V
)
log
P
θ
(
W
m
a
s
k
e
d
∣
W
u
n
m
a
s
k
e
d
,
V
)
\mathcal{L}_{\mathbf{MLM}}=-\mathbb{E}_{(W,V)}\log P_{\theta}(W_{\mathrm{masked}}|W_{\mathrm{unmasked}},\mathbf{V})
LMLM=−E(W,V)logPθ(Wmasked∣Wunmasked,V)
由于SynthText数据集中的word都是拼凑而成,并没有什么有用的上下文语义信息,因此这个论文中的MLM任务是依赖于图像内容来恢复被mask的word,这样可以更好地学习文本图像区域的特征。
总的损失函数:
L
=
L
r
I
C
+
L
W
I
P
+
L
M
I
M
\mathcal{L}=\mathcal{L}_{\mathrm{rIC}}+\mathcal{L}_{\mathrm{WIP}}+\mathcal{L}_{\mathrm{MIM}}
L=LrIC+LWIP+LMIM
实验
实验设置
数据集
在SynthText800K数据集上预训练。保留20K的图像进行验证。对于SynthText数据集而言,只使用文本标签预训练。
主要实验的数据集:
- Total-Text:弯曲文本
- CTW-1500:弯曲文本
- ICDAR2015:常用数据集
- ICDAR2017:9种不同语言的数据集
- MSRA-TD500
- TextOCR:庞大而多样的OCR数据集,图像中文本密度高
预训练设置
将图像大小调整为512×512,在-20°到20°范围内随机旋转,由于没有区域监督,没有做其他的数据增强的手段。
采用AdamW优化器,权重衰减为0.01。学习率前2.5K步不变,2.5K步后从1×10^(-4)线性衰减到0。
ITC和WIP中可学习的参数初始化为0.07并进行剪枝,防止logit缩放大于100,来提高训练稳定性。
视觉和文本嵌入的维度设置为384,视觉tokens和文本tokens的数量设置为1025和30。
评价方案
对输入图像进行了多种数据增强,包括水平翻转、旋转、缩放、随机裁剪和颜色抖动。在训练过程中,ICDAR2015、ICDAR2017和MSRA - TD500的图像大小调整为512 × 512、640 × 640、640 × 640,批大小分别为32、22和22。我们使用Adam优化器在ICDAR2015和MSRA - TD500数据集上对预训练的骨干网络进行了600次微调,基学习率为1 × 10-4,每200次衰减0.1。对于ICDAR2017,EAST训练300个epoch,初始学习率为1 × 10-4,之后每50epoch衰减0.1。
消融实验
网络结构
论文中的FPN为变体,减少了训练计算量,而且提了精度。
MHCA表示多头交叉注意力,w/o是表示只使用自注意力模块,可以看出交叉注意力模块是有作用的。
预训练任务
预训练数据集
ST表示SynthText,TO表示TextOCR。
定性分析结果
注意力图可视化
随着训练的进行,能更好地定位到文本区域。给定单词,能找到单词所在的图像位置。
检测结果可视化
左边一列是STKM的检测结果,右边一列是论文的检测结果。MLM和WIP可以抑制类文本区域上的误检。