论文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
代码:https://github.com/microsoft/Swin-Transformer
目录
1 主要贡献
- Transformer架构在NLP领域已成为通行架构,但在CV领域应用尚不广泛,主要有两点原因:
- 语言处理以单词为基本元素,其尺度不变,而视觉中的基本元素可以有明显的尺度变化;
- 图片中像素的分辨率远高于文本中的单词;
- Swin Transformer针对以上问题,借鉴了CNN的经验,构建了具有多层特征图的Transformer骨干网络,其计算复杂度与图片尺寸成线性关系;
- Swin Transformer引入了连续的自注意力层的窗口分割的“移动”机制,解决了滑动窗口的延迟问题;
- Swin Transformer在图像分类、物体检测和语义分割任务中都取得了出色的结果。
2 原理
2.1 总体架构
上图为Swin-T(tiny)的结构。输入的RGB图片先先被分割为不重叠的patches,每个patch大小为4*4,特征维度为48。阶段一经过线性嵌入和两个Swin Transformer block,将每个patch映射到长度为C的向量。阶段二中将阶段一的输出中2*2的相邻patches进行聚合,经过线性层后输出长度为2C的向量,再使用两个Swin Transformer block进行特征处理。阶段三和阶段四的处理方式类似。这样分层可以获得与经典的CNN分辨率相同的特征图。
Swin Transformer block
上图表示两个连续的Swin Transformer blocks,其中W-MSA和SW-MSA分别表示常规的和移动窗口的多头自注意力模块。每个Swin Transformer block是用基于移动窗口的MSA替换标准的MSA得到的。
2.2 基于移动窗口的自注意力
2.2.1 不重叠窗口中的自注意力
全局自注意力的计算复杂度与图片大小成二次关系,因此不适用于密集预测和高分辨率图像问题。因此本文提出,在局部窗口中计算自注意力,每个窗口包含M*M个patches。此时对于大小为h*w的图片,全局和基于窗口的自注意力的计算复杂度分别为:
Ω
(
MSA
)
=
4
h
w
C
2
+
2
(
h
w
)
2
C
Ω
(
SW-MSA
)
=
4
h
w
C
2
+
2
M
2
h
w
C
\Omega(\text{MSA})=4hwC^2+2(hw)^2C\\ \Omega(\text{SW-MSA})=4hwC^2+2M^2hwC
Ω(MSA)=4hwC2+2(hw)2CΩ(SW-MSA)=4hwC2+2M2hwC
前者与
h
w
hw
hw成二次关系,后者与
h
w
hw
hw成一次关系。
2.2.2 连续block中的窗口划分
不重叠的窗口中的自注意力计算缺乏跨窗口连接,因此在两个连续的Swin Transformer blocks中使用移动窗口的划分方法,如图:
相邻两层的窗口之间有
(
⌊
M
2
⌋
,
⌊
M
2
⌋
)
(\lfloor\frac M2\rfloor,\lfloor\frac M2\rfloor)
(⌊2M⌋,⌊2M⌋)的重叠。则连续的consecutive Swin Transformer blocks计算方法为:
z
^
l
=
W-MSA
(
LN
(
z
l
−
1
)
)
+
z
l
−
1
z
l
=
MLP
(
LN
(
z
^
l
)
)
+
z
^
l
z
^
l
+
1
=
SW-MSA
(
LN
(
z
l
)
)
+
z
l
z
l
+
1
=
MLP
(
LN
(
z
^
l
+
1
)
)
+
z
^
l
+
1
\hat{\mathbf{z}}^l=\text{W-MSA}(\text{LN}(\mathbf{z}^{l-1}))+\mathbf{z}^{l-1}\\ \mathbf{z}^l=\text{MLP}(\text{LN}(\hat{\mathbf{z}}^l))+\hat{\mathbf{z}}^l\\ \hat{\mathbf{z}}^{l+1}=\text{SW-MSA}(\text{LN}(\mathbf{z}^{l}))+\mathbf{z}^{l}\\ \mathbf{z}^{l+1}=\text{MLP}(\text{LN}(\hat{\mathbf{z}}^{l+1}))+\hat{\mathbf{z}}^{l+1}
z^l=W-MSA(LN(zl−1))+zl−1zl=MLP(LN(z^l))+z^lz^l+1=SW-MSA(LN(zl))+zlzl+1=MLP(LN(z^l+1))+z^l+1
2.2.3 针对移动机制的高效批计算
直接采用移动窗口会导致窗口数量增加,且一部分窗口较小;而如果使用padding补全每个窗口,会导致计算量大幅增加。因此本文提出循环移动方法,将较小的不相邻的窗口拼接在一起处理,这样每层处理的窗口数量相同。原理如图:
2.2.4 相对位置偏差
Attention
(
Q
,
K
,
V
)
=
SoftMax
(
Q
K
T
/
d
+
B
)
V
\text{Attention}(Q,K,V)=\text{SoftMax}(QK^T/\sqrt{d}+B)V
Attention(Q,K,V)=SoftMax(QKT/d+B)V
计算attention时加入相对位置偏差
B
∈
R
M
2
×
M
2
B\in\mathbb{R}^{M^2\times M^2}
B∈RM2×M2,其数值取自
B
^
∈
R
(
2
M
−
1
)
×
(
2
M
−
1
)
\hat{B}\in\mathbb{R}^{(2M-1)\times(2M-1)}
B^∈R(2M−1)×(2M−1)。使用相对位置偏差的效果优于不加入位置信息或使用绝对位置嵌入。
2.3 模型变种
窗口大小 M = 7 M=7 M=7,每个head的维度为 d = 32 d=32 d=32,每个MLP的扩展层为 α = 4 \alpha=4 α=4。四个模型变种的架构超参数如下:
3 实验
3.1 图像分类 ImageNet-1K
分别使用ImageNet-1K训练的模型和ImageNet-22K预训练后微调的模型。在两种实验设置下,Swin Transformer的表现均优于模型规模相近的SOTA模型,如图:
3.2 物体检测 COCO
分别对比框架、骨干网络和系统级的性能。在所有对应的对比实验中,Swin Transformer都取得了最优的表现,如图:
3.3 语义分割 ADE20K
使用Swin Transformer作为骨干网络,取得了最优的语义分割效果,如图:
3.4 消融研究
是否使用移动窗口的效果对比,和位置嵌入方式的对比。使用移动窗口和相对位置嵌入的模型效果最好,如图:
不同窗口关联方式的对比。移动窗口优于滑动窗口等其他方法,如图: