论文阅读笔记:Swin Transformer
前言
论文原文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
源码地址:https://github.com/microsoft/Swin-Transformer
本篇论文作者试图扩展Transformer的适用性,使其让NLP可以和CNNs在视觉中一样,作为计算机视觉的通用backbone。但是将Transformer在语言领域的高性能转移到视觉领域的重大挑战主要体现在两种模式的两个差异上:
1.规模:与word tokens不同,视觉元素在规模上可能有很大差异,这是一个在目标检测等任务中受到关注的问题。在现有的基于Transformer的模型中,token都是固定比例的,这一属性不适合视觉应用。
2.计算量:与文本段落中的单词相比,图像中的像素分辨率要高得多。存在许多视觉任务,例如需要在像素级进行密集预测的语义分割,而这对于高分辨率图像上的Transformer来说是很困难的,因为自注意力的计算复杂性是图像大小的二次方。
为了克服这些问题,作者提出了一种通用的backbone,称为Swin Transformer,它可以构造分层特征映射,并且计算复杂度与图像大小成线性关系。如图所示:
可以看出Swin Transformer更贴近于传统的具有多尺寸特征图的CNNs backbone,对图像分别下采样4倍,8倍以及16倍。而Vision Transformer(ViT)一直下采样16倍。多尺寸的特征图有利于不同尺寸大小目标的分割和检测任务(解决了差异1带来的问题)。
Swin Transformer的另一个创新点为将连续的self-attention layers划分成多个窗口,每个窗口单独进行muti-head self-attetion(后文中均简称为MSA)计算,文中称为windows muti-head self-attetion(后文中均简称为W-MSA)计算。在W-MSA后会通过shift window的方式将窗口滑动,使得不同window之间可以信息交互,文中成为shifted windows muti-head self-attetion(后文中均简称为SW-MSA)计算,如下图。通过W-MSA和SW-MSA,成功的降低了计算量并且也保证了各个window之间的信息交互,保证了全局视野(解决了差异2带来的问题)。
那么上述的两种方法具体是如何实现的呢?下文会通过介绍Swin Transformer的网络结构来详细讲解。
网络结构
详细的网络结构参数如下表:
上表表示有Swin-T、 Swin-S、 Swin-B、 Swin-L四种尺寸的网络结构,C代表stage1之后的输出通道:
论文中给出的是Swin Transformer的Swin-T的网络结构,如图(a),其中的Swin Transformer Block如图(b),即将VIT attention Encoding Block中的MSA换成了W-MSA和SW-MSA:
首先将
H
×
W
×
3
H×W×3
H×W×3 的图片输入到Patch Partition中进行分块。每个Patch的大小为
4
×
4
×
3
=
48
4×4×3=48
4×4×3=48,通过Patch Partition后 shape 从
H
×
W
×
3
H×W×3
H×W×3 变为
H
4
×
W
4
×
48
\frac{H}{4} ×\frac{W}{4}×48
4H×4W×48 ,然后在通过Linear Embeding层对每个像素的channel数据做线性变换,变为
H
4
×
W
4
×
C
\frac{H}{4} ×\frac{W}{4}×C
4H×4W×C。接着通过4个Stage进行下采样,除了Stage1是Linear Embedding加一对Swin Transformer Block,其他三个Stage都是一个patch merging加若干对Swin Transformer Block。Patch Partition加Stage1的Linear Embedding,这个过程类似于VIT中的Linear Projection of Flattened Patches操作,即patch embedding过程(可查看博文【论文阅读笔记:Vision Transformer】和博文【Vision Transformer(Pytorch版)代码阅读注释】了解)。一对Swin Transformer Block如图(b)中的包含W-MSA的Block和SW-MSA Block,所以Stage中的Block都是2的倍数。
网络细节
Patch Merging
Patch Merging类似于YoloV5中的Focus模块(Focus模块的介绍可查看博文【从YOLOv5源码yolo.py详细介绍Yolov5的网络结构】),只不过在Patch Merging模块在Focus模块之后再进行LayerNorm和通道上的全连接,使得
H
=
H
0
2
,
W
=
W
0
2
,
C
=
C
0
×
2
H=\frac{H_0}{2} ,W=\frac{W_0}{2},C=C_0×2
H=2H0,W=2W0,C=C0×2。如下图:
W-MSA
引入W-MSA的目的是为了减少计算量,但同时也会使得window之间无法进行信息交互。
论文中提到一个
h
×
w
×
C
h×w×C
h×w×C 的 特征图,MSA 的计算量为公式(1),拆分成 window 宽高均为 M 以后的 W-MSA计算量为公式(2):
这两个计算量公式是根据Muti-Head Self-Attention公式得来的(Attention公式介绍可查看博文【论文阅读笔记:Attention Is All You Need】):
两个矩阵相乘(
A
a
×
b
×
B
b
×
c
=
C
a
×
c
A^{a×b}×B^{b×c}=C^{a×c}
Aa×b×Bb×c=Ca×c)的计算量为:
F
L
O
P
s
=
a
×
b
×
c
+
a
×
(
b
−
1
)
×
c
≈
2
×
a
×
b
×
c
FLOPs = a×b×c + a×(b-1)×c≈2×a×b×c
FLOPs=a×b×c+a×(b−1)×c≈2×a×b×c
其中包含
a
×
b
×
c
a×b×c
a×b×c 次乘法和
a
×
(
b
−
1
)
×
c
a×(b-1)×c
a×(b−1)×c 次加法。
MSA计算步骤如下:
1.由于
h
×
w
×
C
h×w×C
h×w×C 的 特征图相当于有
h
×
w
h×w
h×w 个
C
C
C 维的token向量,将其表示为矩阵:
A
h
w
×
C
A^{hw×C}
Ahw×C
2.Token 矩阵通过乘上 W q C × d k , W k C × d k , W v C × d v W_q^{C×d_k},W_k^{C×d_k},W_v^{C×d_v} WqC×dk,WkC×dk,WvC×dv 获得对应的 Q h w × d k , K h w × d k , V h w × d v Q^{hw×d_k},K^{hw×d_k},V^{hw×d_v} Qhw×dk,Khw×dk,Vhw×dv,即:
依据Vision transformer的源码,假设 d k = d v = C h e a d d_k=d_v= \frac{C}{head} dk=dv=headC,head 为 MutiHead 中的 head 个数,因此 3 对矩阵相乘计算量为 6 h w C 2 h e a d \frac{6hwC^2}{head} head6hwC2。
3.接着计算 Q h w × C h e a d × ( K T ) C h e a d × h w Q^{hw× \frac{C}{head}}×(K^T)^{ \frac{C}{head}×hw} Qhw×headC×(KT)headC×hw,得到 h w × h w hw× hw hw×hw 大小的矩阵,计算量为 2 ( h w ) 2 C h e a d \frac{2(hw)^2C}{head} head2(hw)2C。
4.除以
d
k
\sqrt{d_k}
dk再计算softmax,论文中提出忽略这部分的计算量:
矩阵大小依然为
h
w
×
h
w
hw× hw
hw×hw 。
5.将得到的 h w × h w hw×hw hw×hw 的矩阵再乘上 V h w × C h e a d V^{hw× \frac{C}{head}} Vhw×headC,得到 h w × C h e a d hw× \frac{C}{head} hw×headC 大小的矩阵。计算量为 2 ( h w ) 2 C h e a d \frac{2(hw)^2C}{head} head2(hw)2C。
6.最后再乘上融合矩阵 W O C h e a d ∗ C W_O^{ \frac{C}{head}*C} WOheadC∗C将特征矩阵还原成 h w × C hw×C hw×C 大小的矩阵,计算量为 2 h w C 2 h e a d \frac{2hwC^2}{head} head2hwC2。
总计算量为 8 h w C 2 h e a d + 4 ( h w ) 2 C h e a d \frac{8hwC^2}{head}+\frac{4(hw)^2C}{head} head8hwC2+head4(hw)2C,因为 muti-head self-attention 中 head ≥ 2,所以总计算量 ≤ 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2+2(hw)^2C 4hwC2+2(hw)2C,这里取最大值进行比较。
W-MSA计算步骤如下:
1.将
h
×
w
×
C
h×w×C
h×w×C 的 特征图划分到
h
M
×
w
M
\frac{h}{M}×\frac{w}{M}
Mh×Mw 个宽高均为 M 的 windows 中。
2.将每个宽高为 M 的 windows 进行WSA计算,每个 windows 的计算量为 4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2+2(M)^4C 4(MC)2+2(M)4C。
3. h M × w M \frac{h}{M}×\frac{w}{M} Mh×Mw 个 windows 的总计算量为 h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w C 2 + 2 M 2 h w C \frac{h}{M}×\frac{w}{M}×(4(MC)^2+2(M)^4C) =4hwC^2+2M^2hwC Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2M2hwC.
由于W-MSA只和 h w hw hw 的一次方成线性关系,而MSA会包含 h w hw hw 的二次关系,因此W-MSA大大的降低了计算量。
SW-MSA
前文介绍W-MSA在划分窗口时会带来windows间信息无法交互的问题,所以作者提到使用 shifted window 的方法来增加信息交互,论文的描述如下:
即先从左上角开始使用常规的W-MSA进行窗口划分,每个窗口大小为
M
×
M
M×M
M×M,接着以
s
t
r
i
d
e
=
(
⌊
M
2
⌋
,
⌊
M
2
⌋
)
stride = (\lfloor \frac{M}{2} \rfloor,\lfloor \frac{M}{2} \rfloor)
stride=(⌊2M⌋,⌊2M⌋) 滑动窗口。
以上图为例,左边为
L
a
y
e
r
Layer
Layer
l
l
l 层 feature map 大小为
8
×
8
8×8
8×8,window 大小为
4
×
4
4×4
4×4,共
2
×
2
2×2
2×2 个windows。每个window以
s
t
r
i
d
e
=
(
⌊
M
2
⌋
,
⌊
M
2
⌋
)
=
(
2
,
2
)
stride = (\lfloor \frac{M}{2} \rfloor,\lfloor \frac{M}{2} \rfloor)=(2,2)
stride=(⌊2M⌋,⌊2M⌋)=(2,2) 滑动,得到右边
L
a
y
e
r
Layer
Layer
l
+
1
l+1
l+1 层的 feature map。其过程可以用下图表示:
Efficient batch computation for shifted configuration
SW-MSA 将原本信息不交互的
⌈
h
M
⌉
×
⌈
w
M
⌉
\lceil \frac{h}{M} \rceil×\lceil \frac{w}{M} \rceil
⌈Mh⌉×⌈Mw⌉ 个 windows 做了信息交互并变成了
(
⌈
h
M
⌉
+
1
)
×
(
⌈
w
M
⌉
+
1
)
(\lceil \frac{h}{M} \rceil+1)×(\lceil \frac{w}{M} \rceil+1)
(⌈Mh⌉+1)×(⌈Mw⌉+1) 个,如果对每个 windows 都做 MSA 计算,那么计算量又会比 W-MSA 多,而且 每个 window 的大小也不一样, 无法并行计算。为了解决这个问题,论文提出了一种相邻非重叠窗口之间的连接方式j,如下图:
该过程如果不理解可以看下图:
其将A、B、C三个框中的 window 移动到四个
4
×
4
4×4
4×4 红色框的对应位置,使其凑成四个
4
×
4
4×4
4×4 的window。由于有几个 window 是由不相邻的子窗口组成,需要通过Masked MSA
掩膜计算来限制每个 window 中的不同子窗口的 MSA。
至于 window 是如何移动以及掩膜计算如何实现,请看博文:Swin Transformer代码阅读注释。
Relative Position Bias
论文使用了相对位置偏置,公式如下:
论文中只是提到使用改偏置的效果,并没有产出细节,细节请看博文:Swin Transformer代码阅读注释。
使用了相对位置偏置
(
r
e
l
.
p
o
s
.
)
(rel.pos.)
(rel.pos.)以后带来了明显的提升。