Swin Transformer
Shifted window Transformer: Transformer 结构的计算机视觉通用网络框架
1.介绍
Transformer从NLP迁移到CV上没有大放异彩主要有两点原因:1. 两个领域涉及的尺度不同,NLP的scale是标准固定的,每个单词的向量长度固定,而CV的scale变化范围非常大,图像中通常有大小不一的目标物。2. CV比起NLP需要更大的分辨率,而且CV中使用Transformer的计算复杂度是图像尺度的平方,这会导致计算量过于庞大。为此,(1)作者提出层次化构建网络,同时复杂度与图像大小呈线性关系。如图1,2,进入transformer结构的patch 分辨率逐渐减小(宽高2倍),但对应的维度增加(2倍),类似cnn的通道增加,特征图减小。相比VIT,图1(b)分辨率不变。将transformer适应不同尺度,同时减少计算量(2)引入 shifted windows。窗格内共享multi-head self-attention 的 key 集。
图1
图2
2.方法
2.1 网络结构
结构如图2 所示,输入图像为(H,W,3),设置patch 大小为(4,4,3),展开即(1,48),划分后输入为 H 4 × W 4 × 48 \frac{H}{4}\times \frac{W}{4} \times 48 4H×4W×48,先经过linear embedding将长度48的patch 映射到任意维度,记为C,此时输出为 H 4 × W 4 × C \frac{H}{4}\times \frac{W}{4} \times C 4H×4W×C,并送往Swin Transformer Block,在stage 1中数据维度不变。在stage 2的patch merging中,将每4个 ( 2 × 2 ) (2\times 2) (2×2)patch 合并,形成4C的数据,映射到2C维度的数据,此时,stage 2的输出维度为 H 8 × W 8 × 2 C \frac{H}{8}\times \frac{W}{8} \times 2C 8H×8W×2C,类似的,stage 3与stage 4 的维度如图2 中所示。然后各stage 重复 多个block。
2.2 Swin Transformer block
Swin Transformer block如图2(b)所示,由两部分组成。
计算公式如下 :
z
^
l
=
W
−
M
S
A
(
L
N
(
z
l
−
1
)
)
+
z
l
−
1
\hat{\mathbf{z}}^{l}=\mathrm{W}-\mathrm{MSA}\left(\mathrm{LN}\left(\mathrm{z}^{l-1}\right)\right)+\mathrm{z}^{l-1}
z^l=W−MSA(LN(zl−1))+zl−1
z
l
=
M
L
P
(
L
N
(
z
^
l
)
)
+
z
^
l
\mathrm{z}^{l}=\mathrm{MLP}\left(\mathrm{LN}\left(\hat{\mathrm{z}}^{l}\right)\right)+\hat{\mathbf{z}}^{l}
zl=MLP(LN(z^l))+z^l
z
^
l
+
1
=
S
W
−
M
S
A
(
L
N
(
z
l
)
)
+
z
l
\hat{\mathbf{z}}^{l+1}=\mathrm{SW}-\mathrm{MSA}\left(\mathrm{LN}\left(\mathrm{z}^{l}\right)\right)+\mathrm{z}^{l}
z^l+1=SW−MSA(LN(zl))+zl
z
l
+
1
=
M
L
P
(
L
N
(
z
^
l
+
1
)
)
+
z
^
l
+
1
\mathrm{z}^{l+1}=\mathrm{MLP}\left(\mathrm{LN}\left(\hat{\mathrm{z}}^{l+1}\right)\right)+\hat{\mathbf{z}}^{l+1}
zl+1=MLP(LN(z^l+1))+z^l+1
2.2.1 MSA
引入可学习的相对位置偏差参数B
Attention
(
Q
,
K
,
V
)
=
SoftMax
(
Q
K
T
/
d
+
B
)
V
(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V
(Q,K,V)=SoftMax(QKT/d+B)V
2.2.2 W-MSA
窗口划分:按每个window ( 4 × 4 ) (4\times 4) (4×4)大小不重合划分,例如图3 Layer l 被分成 ( 2 × 2 ) (2\times 2) (2×2)个windows, 每个window内使用多头注意力块计算。
图3
2.2.3 SW-MSA
引入shifted window partition来解决不同window的信息交流问题。
将每个窗口移动
(
2
×
2
)
(2\times 2)
(2×2)个patch,形成如图3 Layer l+1中的划分结果,此时增加了大小不一致的小窗口,为此进行循环补充,这样窗口数量保持不变。图4 中A块移至右下,B块移至右边,C块移至下方。
图4
3. 实验
图像分类
目标检测
表2(a)将不同方法分别使用Swin与Resnet为骨干网络测试性能。
表2(b)为Cascade Mask RCNN下使用不同网络的检测性。©为整体性能比较。
语义分割
不同算法使用不同骨干网络的分割性能比较
消融实验
有无shifted windows,有无位置编码。