目录
📝论文下载地址
🔨代码下载地址
[GitHub-official-Pytorch]
[GitHub-unofficial-Pytorch]
👨🎓论文作者
📦模型讲解
[背景介绍]
Transformer相关背景见[Transformer]。Transformer在处理计算机视觉任务取得不错的效果,但是始终没能超过卷积神经网络的相关方法。例如,针对图像识别任务的ViT网络,见下图。
另外,针对图像目标检测任务的DETR网络,见下图。
目前大多数的针对CV的Transformer方法都是首先将输入图像拆分为补丁,补丁的处理方式与 NLP 应用程序中相同。然后使用几个自监督层进行全局的信息交流,提取特征进行分类。直到谷歌提出NesT的模型,才使得Transformer方法在CV领域成为SOTA。
[模型解读]
作者提出一种具有层次化视觉的Transformer模型,以下简称为SwinT。
[总体结构]
SwinT首先对图片进行非重叠切片,与ViT相同。每个特征都会拉成一维向量。作者设置patch的大小为 4 × 4 4 \times 4 4×4,对于RGB图像来说,每一个切片都对应 4 × 4 × 3 = 48 4 \times 4 \times 3 = 48 4×4×3=48维度的特征。作者对Transformer进行修改,提出了Swin Transformer。在阶段1,利用Swin Transformer块,特征的维度保持在 H / 4 × W / 4 H/4 \times W/4 H/4×W/4。随着网络加深,使用Patch Merging结构进行特征维度的降低。在阶段2,第一个Patch Merging结构将 2 × 2 2\times 2 2×2的相邻切片进行串联,并在 4 C 4C 4C的通道数上使用线形层进行降维,减小为原始的 2 × 2 = 4 2\times2=4 2×2=4倍,输出维度设置为 2 C 2C 2C。Swin Transformer块进行特征转换,特征维度保持在 H / 8 × W / 8 H/8\times W/8 H/8×W/8。改阶段持续两次,分别为阶段3、阶段4,输出特征维度分别为 H / 16 × W / 16 H/16\times W/16 H/16×W/16、 H / 32 × W / 32 H/32\times W/32 H/32×W/32。
[Swin Transformer模块]
Swin Transformer模块是通过使用基于移位窗口的模块替换 Transformer 模块中的标准MSA模块来构建的,其他层保持不变。如上图右侧所示,一个Swin Transformer模块由一个基于移动窗口的MSA模块组成,然后是一个2层MLP,包含GELU非线性激活函数。在每个MSA模块和每个MLP之前应用一个LayerNorm层,在每个模块之后应用残差直连。
[非重叠窗口的自注意力]
为了使网络运行高效,作者设计在一个本地窗口内进行自注意力的计算。将图片分割为几个不重叠的本地窗口。假设每个窗口包含 M × M M\times M M×M个图像切片,全局MSA模块和基于 h × w h\times w h×w切片图像的窗口计算复杂度为: Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \Omega(MSA)=4hwC^2+2(hw)^2C\\\Omega(W-MSA)=4hwC^2+2M^2hwC Ω(MSA)=4hwC2+2(hw)2CΩ(W−MSA)=4hwC2+2M2hwC其中前者对补丁数 h w hw hw是二次的,后者在 M M M固定时是线性的(默认设置为7)。
[连续Swin Transformer块中的移位窗口分区]
基于窗口的自注意力模块缺乏跨窗口的连接,这限制了其建模能力。为了在保持非重叠窗口的高效计算的同时引入跨窗口连接,作者提出了一种移位窗口分区方法,该方法在连续的SwinTransformer块中的两个分区配置之间交替。
如图2所示,第一个模块使用标准的划分方法从左上角像素开始,将 8 × 8 8\times 8 8×8的特征图均匀地划分为大小为 4 × 4 ( M = 4 ) 4\times4(M=4) 4×4(M=4)的 2 × 2 2\times2 2×2个切片的窗口。然后,下一个模块采用从前一层的窗口偏移,通过将窗口从规则分区的窗口中移动 ( ⌊ M 2 ⌋ , ⌊ M 2 ⌋ ) (\lfloor\frac{M}{2}\rfloor,\lfloor\frac{M}{2}\rfloor) (⌊2M⌋,⌊2M⌋)个像素。使用移位窗口分区方法,连续的 Swin Transformer 块计算为: z ^ l = W − M S A ( L N ( z l − 1 ) ) + z l − 1 z l = M L P ( L N ( z ^ l ) ) z ^ l + 1 = S W − M S A ( L N ( z l ) ) + z l z l + 1 = M L P ( L N ( z ^ l + 1 ) ) + z ^ l + 1 \hat z^l=W-MSA(LN(z^{l-1}))+z^{l-1} \\ z^l=MLP(LN(\hat z^l)) \\ \hat z^{l+1}=SW-MSA(LN(z^l))+z^l \\ z^{l+1}=MLP(LN(\hat z^{l+1}))+\hat z^{l+1} z^l=W−MSA(LN(zl−1))+zl−1zl=MLP(LN(z^l))z^l+1=SW−MSA(LN(zl))+zlzl+1=MLP(LN(z^l+1))+z^l+1
[移位的高效批量计算]
移位窗口分区的一个问题是它会导致更多的窗口,窗口从
⌈
h
M
⌉
×
⌈
w
M
⌉
\lceil \frac{h}{M} \rceil \times \lceil \frac{w}{M} \rceil
⌈Mh⌉×⌈Mw⌉移位至
(
⌈
h
M
⌉
+
1
)
×
(
⌈
w
M
⌉
+
1
)
(\lceil \frac{h}{M} \rceil+1) \times (\lceil \frac{w}{M} \rceil+1)
(⌈Mh⌉+1)×(⌈Mw⌉+1),得到的窗可能小于
M
×
M
M \times M
M×M。例如下图所示,整图分为
8
×
8
8\times 8
8×8的切片数目,这里
M
=
4
M=4
M=4,一般情况下,整图会分为4份。但是在偏移之后,整图会分为9份。如果每份都会输入Transformer中,网络的计算量将会增大为2.25倍。作者提出使用循环移位的方法实现更高效的计算。首先将整图进行移位,使用torch.roll
进行
M
2
\frac{M}{2}
2M的移位,这样就使得整图分为4份,每个切片都会对应一个id,例如ABC。在Transformer中会加入一个掩码Mask,使得Transformer只会在相同id的patch之间进行计算。
[网络结构]
作者设计了相关的基础模型,称为Swin-B。林外还设计了Swin-T,Swin-S和Swin-L,它们分别是约是基础模型的
0.25
×
0.25\times
0.25×,
0.5
×
0.5\times
0.5×和
2.0
×
2.0\times
2.0×计算复杂度。Swin-T和Swin-S复杂度分别与Reset-50和Resnet-101类似。Window大小设置为
M
=
7
M=7
M=7。对于所有实验,每个head的查询向量维度为
d
=
32
d = 32
d=32,MLP是以
α
=
4
\alpha= 4
α=4扩展。这些模型变体的体系结构参数是:
Swin-T:
C
=
96
,
l
a
y
e
r
n
u
m
b
e
r
s
=
2
,
2
,
6
,
2
C=96,layer\quad numbers={2,2,6,2}
C=96,layernumbers=2,2,6,2
Swin-S:
C
=
96
,
l
a
y
e
r
n
u
m
b
e
r
s
=
2
,
2
,
18
,
2
C=96,layer \quad numbers={2,2,18,2}
C=96,layernumbers=2,2,18,2
Swin-B:
C
=
128
,
l
a
y
e
r
n
u
m
b
e
r
s
=
2
,
2
,
18
,
2
C=128,layer \quad numbers={2,2,18,2}
C=128,layernumbers=2,2,18,2
Swin-L:
C
=
192
,
l
a
y
e
r
n
u
m
b
e
r
s
=
2
,
2
,
18
,
2
C=192,layer \quad numbers={2,2,18,2}
C=192,layernumbers=2,2,18,2
其中C为特征维度。