论文阅读:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
这篇论文介绍了一种名为Swin Transformer的新型视觉Transformer架构,它能够有效地作为计算机视觉任务的通用骨干网络。Swin Transformer通过使用移位窗口(Shifted Windows)来计算自注意力(Self-Attention),从而解决了从语言领域到视觉领域的Transformer适应过程中遇到的挑战,如视觉实体尺度的大变化和图像像素的高分辨率。
主要贡献和特点包括:
- 层次化Transformer架构:Swin Transformer通过合并图像块(patches)来构建层次化的特征图,这使得模型能够在不同尺度上进行建模,并具有与图像大小成线性关系的计算复杂度。
- 移位窗口方案:通过在连续的自注意力层之间移动窗口划分,Swin Transformer在保持非重叠窗口内高效计算的同时,允许跨窗口连接,增强了模型的建模能力。
3. Method
3.1. Overall Architecture
Swin Transformer架构的概览如图3所示,它展示了一个tiny版本(SwinT)。它首先通过一个patch分割模块将输入的RGB图像分割成不重叠的patch,与ViT类似。每个patch被视为一个“token”,其特征设置为原始像素RGB值的展开。在我们的实现中,我们使用了4×4的patch大小,因此每个patch的特征维度是4×4×3=48。在这个原始值特征上应用了一个线性嵌入层,将其投影到任意维度(记为C)。
在这些patch token上应用了若干个带有修改自注意力计算的Transformer块(Swin Transformer块)。Transformer块保持了token的数量( H 4 × W 4 \frac H4 \times \frac W4 4H×4W),并与线性嵌入层一起被称为“stage 1”。
为了产生分层表示,随着网络的深入,通过patch merging layers来减少token的数量。这通过2×2=4的倍数减少了token的数量(分辨率的2×下采样),输出维度被设置为2C。随后应用Swin Transformer块进行特征转换,分辨率保持在 H 8 × W 8 \frac H8 \times \frac W8 8H×8W。这个patch合并和特征转换的第一个块被称为“第2阶段”。这个过程在“第3阶段”和“第4阶段”重复,输出分辨率分别为 H 16 × W 16 \frac H{16} \times \frac W{16} 16H×16W和 H 32 × W 32 \frac H{32} \times \frac W{32} 32H×32W。这些阶段共同产生了一个层次化表示,其特征图分辨率与典型卷积网络(例如VGG和ResNet)相同。因此,所提出的架构可以方便地替换现有方法中的骨干网络,用于各种视觉任务。
Swin Transformer Block
Swin Transformer是通过将Transformer块中的标准多头自注意力(MSA)模块替换为基于移位窗口的模块(在第3.2节中描述)而构建的,其他层保持不变。如图3(b)所示,一个Swin Transformer块由一个基于移位窗口的MSA模块组成,接着是一个两层的MLP,中间有GELU非线性。在每个MSA模块和MLP之前应用LayerNorm(LN)层,并在每个模块之后应用残差连接。
3.2. Shifted Window based Self-Attention
标准的Transformer架构及其图像分类的适应版本都进行全局自注意力计算,其中计算一个token与所有其他token之间的关系。全局计算导致与token数量成二次方的复杂度,这使得它不适合许多需要大量token进行密集预测或表示高分辨率图像的视觉问题。
Self-attention in non-overlapped windows
为了高效建模,我们提出在局部窗口内计算自注意力。窗口被安排为非重叠地均匀划分图像。假设每个窗口包含
M
×
M
M\times M
M×M个patch,全局MSA模块和基于窗口的自注意力模块在
h
×
w
h \times w
h×wpatch的图像上的计算复杂度分别为:
Ω
(
M
S
A
)
=
4
h
w
C
2
+
2
(
h
w
)
2
C
,
Ω
(
W
−
M
S
A
)
=
4
h
w
C
2
+
2
M
2
h
w
C
,
\begin{aligned}&\Omega(\mathbf{MSA})=4hwC^2+2(hw)^2C,\\&\Omega(\mathbf{W-MSA})=4hwC^2+2M^2hwC,\end{aligned}
Ω(MSA)=4hwC2+2(hw)2C,Ω(W−MSA)=4hwC2+2M2hwC,
其中前者与patch数量
h
w
hw
hw成二次方关系,后者在
M
M
M固定时(默认设置为7)是线性的。全局自注意力计算对于大
h
w
hw
hw来说是不可承受的,而基于窗口的自注意力是可扩展的。
Shifted window partitioning in successive blocks
基于窗口的自注意力模块缺乏跨窗口的连接,这限制了其建模能力。为了在保持非重叠窗口高效计算的同时引入跨窗口连接,我们提出了一种移位窗口划分方法,它在连续的Swin Transformer块之间交替使用两种划分配置。
如图2所示,第一个模块使用从左上角像素开始的常规窗口划分策略,并将8×8特征图均匀划分为4×4(M=4)的2×2窗口。然后,下一个模块采用与前一层移位的窗口划分配置,将窗口从常规划分的窗口中移动 ( ⌊ M 2 ⌋ , ⌊ M 2 ⌋ ) (\lfloor\frac M2\rfloor,\lfloor\frac M2\rfloor) (⌊2M⌋,⌊2M⌋)像素。
通过移位窗口划分方法,连续的Swin Transformer块计算为:
z
^
l
=
W
−
M
S
A
(
L
N
(
z
l
−
1
)
)
+
z
l
−
1
,
z
l
=
M
L
P
(
L
N
(
z
^
l
)
)
+
z
^
l
,
z
^
l
+
1
=
S
W
−
M
S
A
(
L
N
(
z
l
)
)
+
z
l
,
z
l
+
1
=
M
L
P
(
L
N
(
z
^
l
+
1
)
)
+
z
^
l
+
1
,
\begin{aligned} &\hat{\mathbf{z}}^{l}=\mathrm{W-MSA}\left(\mathrm{LN}\left(\mathbf{z}^{l-1}\right)\right)+\mathbf{z}^{l-1}, \\ &\mathbf{z}^{l}=\mathbf{MLP}\left(\mathbf{LN}\left(\hat{\mathbf{z}}^{l}\right)\right)+\hat{\mathbf{z}}^{l}, \\ &{\hat{\mathbf{z}}}^{l+1}={\mathrm{SW-MSA}}\left({\mathrm{LN}}\left({\mathbf{z}}^{l}\right)\right)+{\mathbf{z}}^{l}, \\ &\mathbf{z}^{l+1}=\mathsf{MLP}\left(\mathsf{LN}\left(\hat{\mathbf{z}}^{l+1}\right)\right)+\hat{\mathbf{z}}^{l+1}, \end{aligned}
z^l=W−MSA(LN(zl−1))+zl−1,zl=MLP(LN(z^l))+z^l,z^l+1=SW−MSA(LN(zl))+zl,zl+1=MLP(LN(z^l+1))+z^l+1,
其中
z
^
l
\hat{\mathbf{z}}^{l}
z^l和
z
l
\mathbf{z}^{l}
zl分别表示块L的(S)WMSA模块和MLP模块的输出特征;W-MSA和SW-MSA分别表示使用规则和移位窗口分区配置的基于窗口的多头自关注。
Efficient batch computation for shifted configuration
移位窗口划分的一个问题是它会导致更多的窗口,从常规划分的 ⌈ h M ⌉ × ⌈ w M ⌉ \lceil\frac{{h}}M\rceil\times\lceil\frac wM\rceil ⌈Mh⌉×⌈Mw⌉增加到移位配置的 ( ⌈ h M ⌉ + 1 ) × ( ⌈ w M ⌉ + 1 ) (\lceil\frac{{h}}M\rceil+1)\times(\lceil\frac wM\rceil+1) (⌈Mh⌉+1)×(⌈Mw⌉+1),其中一些窗口的大小会小于 M × M M\times M M×M。一个简单的解决方案是对较小的窗口进行填充,使其达到M × M的大小,并在计算注意力时屏蔽掉填充的值。当常规划分的窗口数量较少时,例如2 × 2,这种简单解决方案带来的计算增加是相当大的(2 × 2 → 3 × 3,增加了2.25倍)。在这里,我们提出了一种更高效的批量计算方法,通过向顶部左侧方向循环移位,如图4所示。移位后,一个批量窗口可能由几个在特征图上不相邻的子窗口组成,因此采用一个屏蔽机制来限制自注意力计算仅在每个子窗口内进行。通过循环移位,批量窗口的数量保持与常规窗口划分相同,因此也是高效的。表5展示了这种方法的低延迟。
Relative position bias
Swin Transformer之相对位置编码详解 (zhihu.com)
在计算自注意力时,我们遵循[49, 1, 32, 33]的做法,通过在每个头部计算相似性时包含相对位置偏差
B
∈
R
M
2
×
M
2
B\in\mathbb{R}^{M^2\times M^2}
B∈RM2×M2
Attention
(
Q
,
K
,
V
)
=
SoftMax
(
Q
K
T
/
d
+
B
)
V
\operatorname{Attention}(Q,K,V)=\operatorname{SoftMax}(QK^T/\sqrt d+B)V
Attention(Q,K,V)=SoftMax(QKT/d+B)V
M
2
M^2
M2是一个窗口中的补丁的数量。由于沿每个轴的相对位置在
[
−
M
+
1
,
M
−
1
]
[−M+1,M−1]
[−M+1,M−1]的范围内,我们将较小尺寸的偏置矩阵
B
^
∈
R
(
2
M
−
1
)
×
(
2
M
−
1
)
\hat{B}\in\mathbb{R}^{(2M-1)\times(2M-1)}
B^∈R(2M−1)×(2M−1)参数化,并且B中的值取自
B
^
\hat{B}
B^。
我们观察到,与没有这个偏差项或使用绝对位置嵌入的对应方法相比,有显著的改进,如表4所示。在输入中进一步添加绝对位置嵌入,如[20]中所述,会略微降低性能,因此在我们实现中没有采用。预训练中学到的相对位置偏差也可以通过双三次插值[20, 63]用于初始化不同窗口大小的微调模型。
3.3. Architecture Variants
我们构建了基础模型,称为Swin-B,其模型大小和计算复杂度与ViTB/DeiT-B相似。我们还引入了Swin-T、Swin-S和Swin-L,它们分别是模型大小和计算复杂度约为ViTB/DeiT-B的0.25倍、0.5倍和2倍的版本。请注意,Swin-T和Swin-S的复杂度与ResNet-50 (DeiT-S) 和ResNet-101相似。默认情况下,窗口大小设置为M = 7。对于所有实验,每个头部的查询维度为d = 32,每个MLP的扩展层为α = 4。这些模型变体的架构超参数如下:
- Swin-T: C = 96, 层数 = {2, 2, 6, 2}
- Swin-S: C = 96, 层数 = {2, 2, 18, 2}
- Swin-B: C = 128, 层数 = {2, 2, 18, 2}
- Swin-L: C = 192, 层数 = {2, 2, 18, 2}
其中C是第一阶段隐藏层的通道数。这些模型变体在ImageNet图像分类任务中的模型大小、理论计算复杂度(FLOPs)和吞吐量列在表1中。