目录
前言
Swin-Transformer是微软亚洲研究院发表于ICCV 2021的一篇论文,并获得了当年的最佳论文。对比 ViT,Swin-Transformer大大降低了计算量,提供了更加通用的基于Transformer的计算机视觉任务的主干网络,并且能应用到分类、检测、分割等多种计算机视觉任务中。
1. 模型的特点
Swin Transformer与Vision Transformer对比有两个突出特征:
- 首先Swin Transformer所构建的feature map是具有层次性的,与卷积网络类似,随着特征层的不断加深,feature map的高和宽越来越小。由于像CNN一样能够构建出具有层次性的特征图,Swin Transformer对于目标检测和分割任务都有更大的优势。ViT则一直保持16倍的下采样,没办法构建出具有层次性的特征图。
- Swin Transformer是用一个个窗口的形式将feature map分割开了,窗口与窗口之间没有重叠。而在ViT中feature map是一个整体,没有对其进行分割。窗口的划分使得Swin Transformer能够在每个窗口内部进行多头自注意力的计算,窗口之间不去进行信息传递。好处在于能够大大降低运算量,尤其是在下采样率较低的浅层网络,相比于直接对整个特征图进行多头自注意力的计算,大大降低了运算量。
2. 模型结构
假设输入一张高H宽W的3通道彩色图像,首先通过Patch Partition模块输出
H
4
×
W
4
×
48
\frac{H}{4} \times \frac{W}{4} \times 48
4H×4W×48的特征图,之后与ResNet网络类似,通过不同的Stage对特征图进行下采样,并且每次下采样后Channel数就会翻倍。需要注意Stage1模块与Stage2、3、4不同之处在于其第一个模块是一个Linear Embeding层,而2、3、4第一个模块是Patch Merging层。实际上,Patch Partition + Linear Embeding的功能与Patch Merging差不多。
2.1 Patch Partition + Linear Embedding
例如输入一张图像,Patch Partition会用 4 × 4 4 \times 4 4×4大小的窗口对其进行分割。分割之后对每一个窗口在channel方向进行展平,也就是对每个像素沿深度方向进行拼接。由于每个像素都是RGB三个通道的,则 16 × 3 = 48 16 \times 3 = 48 16×3=48,所以通过Patch Partition之后,图像的高和宽就缩减为原来的 1 4 \frac{1}{4} 41,通道数变为48。
接下来再通过Linear Embedding层对输入特征矩阵进行调整,输出通道数变为
C
C
C,
C
C
C的大小与模型的选择有关。注意,在Linear Embedding中还包括了一个Layer Norm。在实际实现中,Patch Partition和Linear Embedding的操作都是通过卷积来完成的。
2.2 Patch Merging
Patch Merging的作用是进行下采样,特征图的高和宽会缩减为原来的一半,并且通道数会翻倍。假设输入特征矩阵的高和宽都是
4
×
4
4 \times 4
4×4, 输入通道数为1。以
2
×
2
2 \times 2
2×2大小作为一个窗口,将每个窗口中相同位置的像素取出来,就能得到4个特征矩阵。将这4个特征矩阵在深度方向进行拼接,然后在深度方向进行LayerNorm的处理,最后再通过一个全连接层,对每一个像素的深度方向进行线性映射,输出通道数为2。
2.3 Swin-Transformer Block
对于每个Stage还会重复堆叠每个Swin Transformer Block,注意重复次数均为偶数。**为什么都要重复偶数次呢?**结构如下:
在第一个Block中可以看到,是将ViT中的Multi-head Self-Attention模块替换成了W-MSA (Windows Multi-head Self-Attention) 。第二个Block中则是使用了SW-MSA (Shifted Windows Multi-head Self-Attention) 。 这两个模块都是成对去使用的。
对于Swin Transformer模型的架构当中,其实后面还有一些层结构,比如对于分类网络而言,在Stage 4后面还会接上Layer Norm、全局池化以及一个全连接层进行一个最终输出。
2.4 W-MSA
回顾Multi-head Self-Attention,会对每一个像素求Q、K、V,对于每一个像素所求得的Q会和特征图中的每一个K进行匹配,然后再进行一系列的操作。
对于Swin-Transformer中所提出的Windows Multi-head Self-Attention模块,会对特征图分成一个一个Window,对每一个Window的内部进行Multi-head Self-Attention,但是,Window和Window之间没有任何通信。
目的:减少计算量。
缺点:窗口之间无法进行信息交互。
理论计算量:
Ω
(
M
S
A
)
=
4
h
w
C
2
+
2
(
h
w
)
2
C
\Omega(MSA) = 4hwC^2 + 2(hw)^2C
Ω(MSA)=4hwC2+2(hw)2C
Ω
(
W
−
M
S
A
)
=
4
h
w
C
2
+
2
M
2
h
w
C
\Omega(W-MSA) = 4hwC^2 + 2M^2hwC
Ω(W−MSA)=4hwC2+2M2hwC
- h代表feature map的高度
- w代表feature map的宽度
- C代表feature map的深度
- M代表每个窗口(Windows)的大小
MSA计算量公式是怎么来的?
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
M
a
x
(
Q
K
T
d
)
V
Attention(Q,K,V) = SoftMax(\frac{QK^T}{\sqrt{d}})V
Attention(Q,K,V)=SoftMax(dQKT)V
对于feature map中的每个像素(或称为token,patch),都要通过矩阵
W
q
,
W
k
,
W
v
W_q, W_k, W_v
Wq,Wk,Wv生成对应的query(q),key(k)以及value(v)。这里假设q,k,v的向量长度与feature map的深度C保持一致。那么对应所有像素生成Q的过程如下式:
A
h
w
×
C
⋅
W
q
C
×
C
=
Q
h
w
×
C
A_{hw \times C} \cdot {W^q}_{C \times C} = Q_{hw \times C}
Ahw×C⋅WqC×C=Qhw×C
- A h w × C A_{hw \times C} Ahw×C是所有像素(token)拼接再一起得到的矩阵(一共有 h w hw hw个像素,每个像素深度为 C C C)
- W q C × C {W^q}_{C \times C} WqC×C为生成query的变换矩阵
- Q h w × C Q_{hw \times C} Qhw×C是所有像素与变换矩阵 W q C × C W_q^{C \times C} WqC×C相乘而得到的所有query拼接后的矩阵
根据矩阵乘法运算规则,可以得到生成Q共进行了 h w C 2 hwC^2 hwC2次乘法运算(一共 h w C hwC hwC个像素,每个像素进行C次乘法)。同理,生成K和V都是 h w C 2 hwC^2 hwC2,那么总共 3 h w C 2 3hwC^2 3hwC2次乘法。
接下来
Q
Q
Q和
K
T
K^T
KT相乘,对应计算量为
(
h
w
)
2
C
(hw)^2C
(hw)2C:
Q
h
w
×
C
⋅
K
T
C
×
h
w
=
X
h
w
×
h
w
Q_{hw \times C} \cdot {K^T}_{C \times hw} = X_{hw \times hw}
Qhw×C⋅KTC×hw=Xhw×hw
忽略除以
d
\sqrt{d}
d以及softmax的计算量,假设通过
S
o
f
t
M
a
x
(
Q
K
T
d
)
SoftMax(\frac{QK^T}{\sqrt{d}})
SoftMax(dQKT)得到
Λ
h
w
×
h
w
\Lambda^{hw \times hw}
Λhw×hw,最后还要乘上
V
V
V,对应的计算量是
(
h
w
)
2
C
(hw)^2C
(hw)2C:
Λ
h
w
×
h
w
⋅
V
h
w
×
C
=
B
h
w
×
C
\Lambda^{hw \times hw} \cdot V^{hw \times C} = B^{hw \times C}
Λhw×hw⋅Vhw×C=Bhw×C
综上,单头Self-Attention总共需要
3
h
w
C
2
+
(
h
w
)
2
C
+
(
h
w
)
2
C
=
3
h
w
C
2
+
2
(
h
w
)
2
C
3hwC^2+(hw)^2C+(hw)^2C=3hwC^2+2(hw)^2C
3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C。而对于多头注意力机制,有
M
u
l
t
i
H
e
a
d
(
Q
,
K
,
V
)
=
C
o
n
c
a
t
(
h
e
a
d
1
,
…
,
h
e
a
d
h
)
W
O
MultiHead(Q,K,V) = Concat(head_1,\ldots,head_h)W^O
MultiHead(Q,K,V)=Concat(head1,…,headh)WO
也就是说,多头注意力机制仅是将每个单头自注意力拼接起来乘上了矩阵
W
O
W^O
WO,计算量为:
B
h
w
×
C
⋅
W
O
C
×
C
=
O
h
w
×
C
B^{hw \times C} \cdot W_O^{C \times C}=O^{hw \times C}
Bhw×C⋅WOC×C=Ohw×C
所以总共加起来是:
4
h
w
C
2
+
2
(
h
w
)
2
C
4hwC^2+2(hw)^2C
4hwC2+2(hw)2C
W-MSA的计算公式是怎么来的?
W-MSA是将整个feature map划分为一个个窗口(Windows),假设每个窗口的高和宽都是M,那么总共会得到
h
M
×
w
M
\frac{h}{M} \times \frac{w}{M}
Mh×Mw个窗口,然后再、在每个窗口内使用多头注意力模块。在MSA的计算中,高
h
h
h宽
w
w
w深度为
C
C
C的feature map的计算量为
4
h
w
C
2
+
2
(
h
w
)
2
C
4hwC^2+2(hw)^2C
4hwC2+2(hw)2C,带入
M
×
M
M \times M
M×M有:
4
(
M
C
)
2
+
2
(
M
)
4
C
4(MC)^2+2(M)^4C
4(MC)2+2(M)4C
又因为有
h
M
×
w
M
\frac{h}{M} \times \frac{w}{M}
Mh×Mw个窗口,则:
h
M
×
w
M
×
(
4
(
M
C
)
2
+
2
(
M
)
4
C
)
=
4
h
w
C
2
+
2
M
2
h
w
C
\frac{h}{M} \times \frac{w}{M} \times (4(MC)^2+2(M)^4C)=4hwC^2+2M^2hwC
Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2M2hwC
故W-MSA模块的计算量为:
4
h
w
C
2
+
2
M
2
h
w
C
4hwC^2+2M^2hwC
4hwC2+2M2hwC
假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:
2
(
h
w
)
2
C
−
2
M
2
h
w
C
=
2
×
11
2
4
×
128
−
2
×
7
2
×
11
2
2
×
128
=
40124743680
2(hw)^2C - 2M^2hwC = 2 \times 112^4 \times 128 - 2 \times 7^2 \times 112^2 \times 128 = 40124743680
2(hw)2C−2M2hwC=2×1124×128−2×72×1122×128=40124743680
2.5 SW-MSA
使用W-MSA的问题在于窗口之间没有信息交互,为了解决这个问题作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块。在第l层上采用的是W-MSA模块的话,在第l+1层则要使用SW-MSA模块。在SW-MSA模块中Window的划分有所不同,可以看作将每个Window向右和向下移动了两个像素,融合了不同Window之间的信息。
但是通过Shifted Window之后,如果要并行计算,就需要对不足
4
×
4
4 \times 4
4×4大小的Window进行填充,相当于计算9个Window的注意力,带来了计算量的增加。论文中提出了一种简化运算的方法:
将Window 0标记为区域A,Window 1 和 2标记区域C,Window 3 和 6标记为区域B。
先将A和C移动到下面:
再将A和B移动到右边:
移动之后可以重新划分Window。Window 4不变,将Window 3和5划分到一起,同样将7和1合并在一起,将8、6、2、0合并在一起。这样就在新的4个Window内进行自注意力计算,保持了计算量不变。
但是如果直接简单粗暴地在每个Window内进行计算,又会引入新的问题。例如,对于Window 5 和 Window 3来说,它们本来是两个不相邻的区域,但是现在强行划分在了同一个Window内,直接对它们进行MSA计算是有问题的。所以就希望能够在区域内单独计算Window 5和Window 3的MSA。
原论文中采用了Masked MSA来解决这一问题。假设求得了0位置处的
q
0
q_0
q0,
q
0
q_0
q0要和所有位置的像素进行匹配,就会依次生成
α
0
,
0
,
α
0
,
1
,
…
,
α
0
,
15
\alpha_{0,0},\alpha_{0,1},\ldots,\alpha_{0,15}
α0,0,α0,1,…,α0,15,对应Attention计算公式中
Q
K
T
QK^T
QKT的过程,但是,在计算Window 5内部的MSA时,不希望引入Window 3的信息。源码实现中,将Window 3所对应的
α
0
,
2
,
α
0
,
3
\alpha_{0,2},\alpha_{0,3}
α0,2,α0,3等全部减去100。由于
α
\alpha
α本身数值是很小的,所以在减去100之后就变成了非常大的负数,在经过
S
o
f
t
m
a
x
Softmax
Softmax处理之后,Window 3所对应的
α
\alpha
α就全部近似为0了。
注意:全部计算完成之后,需要将数据还原回初始位置。
2.6 相对位置偏置(Relative position bias)
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
M
a
x
(
Q
K
T
d
+
B
)
V
Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt {d}}+B)V
Attention(Q,K,V)=SoftMax(dQKT+B)V
在计算Attention时,加上了一个相对位置偏置
B
B
B。通过下表可以看出,对比不加入位置编码和加入ViT中的绝对位置编码,加入相对位置偏置后的结果,在分类、目标检测、分割任务中都会表现得更好。
**什么是相对位置偏置?**假设feature map是
2
×
2
2 \times 2
2×2大小的,那对于蓝色的像素,其绝对位置索引就是
(
0
,
0
)
(0,0)
(0,0),第0行,第0列。那么匹配蓝色像素时的相对位置索引,就是用蓝色的绝对位置索引减去相应位置的绝对位置索引。将每一个相对位置索引的矩阵在行方向上进行展平,拼接在一起可以得到一个大的矩阵。根据每个位置的相对位置索引都可以在Relative position bias当中取到一个对应的参数。
在原作者的代码当中使用的并不是一个二维的位置坐标,而是使用了一维坐标。如何将二维转化成一维呢?
首先将偏移从0开始,行、列标上加上
M
−
1
M-1
M−1,
M
M
M对应窗口大小。然后在行标上乘上
2
M
−
1
2M-1
2M−1,最后再将行标和列表相加,得到最终的一维位置矩阵(relative position index)。然后在元素个数为
(
2
M
−
1
)
×
(
2
M
−
1
)
(2M-1) \times (2M-1)
(2M−1)×(2M−1)的relative position bias table中取出对应的值,最终得到relative position bias,即最终使用到的B。