目录
📝论文下载地址
🔨代码下载地址
[GitHub-official-Pytorch]
[GitHub-unofficial-Pytorch]
👨🎓论文作者
📦模型讲解
[背景介绍]
Transformer相关背景见[Transformer]。Transformer在处理计算机视觉任务取得不错的效果,但是始终没能超过卷积神经网络的相关方法。例如,针对图像识别任务的ViT网络,见下图。

另外,针对图像目标检测任务的DETR网络,见下图。

目前大多数的针对CV的Transformer方法都是首先将输入图像拆分为补丁,补丁的处理方式与 NLP 应用程序中相同。然后使用几个自监督层进行全局的信息交流,提取特征进行分类。直到谷歌提出NesT的模型,才使得Transformer方法在CV领域成为SOTA。
[模型解读]
在本节中,作者提出了一种高效的ResT模型,将Transformer用于图像识别任务,并拥有不错的速度。
[总体结构]
ResT的总体结构与ResNet相似,都是有4个stage。每一个stage中有4部分组成,如下图所示。

[标准Transformer]
标准的Transformer主要包含两个结构,MSA与FFN,每个结构都会有跨层链接,归一化采用LayerNorm的方法。假设输入Transformer的特征
x
∈
R
n
×
d
m
x\in \mathbb{R}^{n \times d_m}
x∈Rn×dm,
n
n
n为空间维度,
d
m
d_m
dm为通道维度。对于标准的Transformer其输出为:
y
=
x
′
+
F
F
N
(
L
N
(
x
′
)
)
,
a
n
d
x
′
=
x
+
M
S
A
(
L
N
(
x
)
)
y=x'+FFN(LN(x')),\quad and \quad x'=x+MSA(LN(x))
y=x′+FFN(LN(x′)),andx′=x+MSA(LN(x))
假设MSA中仅有一个head,那么
d
k
=
d
m
d_k=d_m
dk=dm,其中MSA 的计算成本为
O
(
2
d
m
n
2
+
4
d
m
2
n
)
O(2d_mn^2+4d_m^2n)
O(2dmn2+4dm2n)。在FFN中,一般
d
f
=
4
d
m
d_f=4d_m
df=4dm,FFN计算将花费
8
n
d
m
2
8nd_m^2
8ndm2。

[EMSA]
对于标准的Transformer,MSA有两个缺点:
(1)MSA尺度的计算是根据输入维度
d
m
d_m
dm与
n
n
n的二次方计算的,导致训练和推理的开销巨大;
(2)MSA中的每一个head只负责输入的一个子集,这可能会损害网络的性能,特别是当每个子集中的通道维度
d
k
d_k
dk太低时,使得Q和K的点积不能够构成一个信息匹配函数。
为了解决这些问题,作者提出了一种高效的多头自注意力模块,如下图所示。

(1)与MSA类似,EMSA首先获得查询向量Q。
(2)为了减少存储占用,将二维输入
x
∈
R
n
×
d
m
x\in \mathbb{R}^{n \times d_m}
x∈Rn×dm重构为三维
x
^
∈
R
d
m
×
h
×
w
\hat x \in \mathbb{R}^{d_m \times h \times w}
x^∈Rdm×h×w,之后送入卷积神经网络进行s倍的下采样。这里作者设置
s
=
8
/
k
s=8/k
s=8/k。卷积的步长为s,卷积核大小为为s+1,padding为s/2。
(3)在空间维度缩小后
x
^
∈
R
d
m
×
h
/
s
×
w
/
s
\hat x \in \mathbb{R}^{d_m \times h/s \times w/s}
x^∈Rdm×h/s×w/s再重构为二维向量
x
^
∈
R
n
′
×
d
m
,
n
′
=
h
/
s
×
w
/
s
\hat x \in \mathbb{R}^{n' \times d_m},n'=h/s \times w/s
x^∈Rn′×dm,n′=h/s×w/s,通过
x
^
\hat x
x^获得键向量K与值向量V。
(4)之后根据下式进行自注意运算。
E
M
S
A
(
Q
,
K
,
V
)
=
I
N
(
S
o
f
t
m
a
x
(
C
o
n
v
(
Q
K
T
d
k
)
)
)
V
EMSA(\bm{Q},\bm{K},\bm{V})=IN(Softmax(Conv(\frac{\bm{Q}\bm{K}^T}{\sqrt{d_k}})))\bm{V}
EMSA(Q,K,V)=IN(Softmax(Conv(dkQKT)))V
其中Conv()为标准的1×1卷积操作,IN()表示Instance Normalization。
(5)最后,将所有head特征串联生成最后的特征。
相对标准的MSA,EMSE的计算复杂度为
O
(
2
d
m
n
2
s
2
+
2
d
m
2
n
(
1
+
1
s
2
)
+
d
m
n
(
s
+
1
)
2
s
2
+
k
2
n
2
s
2
)
O(\frac{2d_mn^2}{s^2}+2d_m^2n(1+\frac{1}{s^2})+d_mn\frac{(s+1)^2}{s^2}+\frac{k^2n^2}{s^2})
O(s22dmn2+2dm2n(1+s21)+dmns2(s+1)2+s2k2n2),假设s>1则会低于MSA。
此外,在 EMSA 之后添加 FFN 以进行特征转换和非线性。每个高效 Transformer 模块的输出为:
y
=
x
′
+
F
F
N
(
L
N
(
x
′
)
)
,
a
n
d
x
′
=
x
+
E
M
S
A
(
L
N
(
x
)
)
y=x'+FFN(LN(x')),\quad and\quad x'=x+EMSA(LN(x))
y=x′+FFN(LN(x′)),andx′=x+EMSA(LN(x))
[Patch Embedding/切片嵌入]
以ViT为例,网络将输入图片
x
∈
R
3
×
h
×
w
x\in \mathbb R^{3 \times h \times w}
x∈R3×h×w裁剪为多个
s
×
s
s \times s
s×s的Patch,之后进行嵌入编码生成
x
∈
R
n
×
c
x \in \mathbb R^{n \times c}
x∈Rn×c,
n
=
h
w
/
p
2
n=hw/p^2
n=hw/p2,c为编码的维度。在作者提出的ResT网络中,首先输入stem模块进行4×4的下采样,作者采用的是三个卷积层进行放缩,步长分别为2、1、2。前两层使用BatchNorm与ReLU激活函数。
在stage2、3、4的Patch Embedding中,会对特征空间维度进行4倍下采样,通道维度变为2倍,通过步长为2、填充为1的标准3×3卷积来完成。
[Position Encoding/位置编码]
以ViT为例,将一组可学习的参数
θ
\theta
θ作为位置编码。即编码后输出为:
x
^
=
x
+
θ
\hat x =x+\theta
x^=x+θ 由此导致输入网络的特征尺寸不能变化。
本文作者假设参数
θ
\theta
θ与输入
x
x
x有关,即
θ
=
G
L
(
x
)
\theta=GL(x)
θ=GL(x):
x
^
=
x
+
G
L
(
x
)
\hat x =x+GL(x)
x^=x+GL(x)
θ
\theta
θ也可以通过更灵活的注意力机制获得。作者使用逐像素注意力(PA)模块来进行位置编码。PA 使用3×3卷积来获得像素级权重,然后通过sigmoid函数
σ
(
⋅
)
\sigma(\sdot )
σ(⋅)进行缩放。 PA模块的位置编码可以表示为:
x
^
=
P
A
(
x
)
=
x
×
σ
(
D
W
C
o
n
v
(
x
)
)
\hat x =PA(x)=x\times \sigma(DWConv(x))
x^=PA(x)=x×σ(DWConv(x)) 将PA位置编码与Patch Embedding结合结构如下。

[Classification Head/分类头]
分类头采用全局平均池化与全连接层。网络的结构见下表。Conv-k_c_s表示卷积层,卷积核大小为k,输出通道数为c,步长为s。MLP_c表示中间特征维度为4c,输入输出特征维度为c。MCSA_n_r表示EMSA,n为head数目,r为下采样倍数。

[结果分析]
[图像分类-ImageNet-1k]

[目标检测-COCO]

[实例分割-COCO]

[消融实验]
[模型类别]

[EMSA结构]

[位置编码]
