Content
Abstract
- propose Shifted Window (Swin) self-attention
- introduce Relative Position Encoding (RPE)
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red). It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks. (b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self-attention globally.
Method
model architecture
(a) The architecture of a Swin Transformer (Swin-T); (b) two successive Swin Transformer Blocks. W-MSA and SW-MSA are multi-head self-attention modules with regular and shifted windowing configurations, respectively.
patch partition: split input into $4\times4$ non-overlapping patches
stage 1
linear embed: project features to dim=C
$\mathbb{R}^{B\times3\times H\times W}\rightarrow\mathbb{R}^{B\times(\frac{H}{P}\times\frac{W}{P})\times C},\ P=4,\ C=96$
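The two steps above (partition into $4\times4$ patches, then a linear projection to $C$ channels) can be sketched in NumPy; the function name and the `W_proj` weight are hypothetical stand-ins for the learned embedding:

```python
import numpy as np

def patch_partition_embed(x, P=4, C=96, W_proj=None):
    """Split (B, 3, H, W) into non-overlapping PxP patches and linearly
    embed each flattened patch to dimension C (hypothetical sketch)."""
    B, ch, H, W = x.shape
    if W_proj is None:
        W_proj = np.random.randn(ch * P * P, C) * 0.02  # stand-in weight
    # (B, ch, H/P, P, W/P, P) -> (B, H/P, W/P, ch, P, P)
    x = x.reshape(B, ch, H // P, P, W // P, P).transpose(0, 2, 4, 1, 3, 5)
    # flatten each patch, then project: (B, H/P * W/P, ch*P*P) @ (ch*P*P, C)
    x = x.reshape(B, (H // P) * (W // P), ch * P * P)
    return x @ W_proj

tokens = patch_partition_embed(np.zeros((2, 3, 224, 224)))
print(tokens.shape)  # (2, 3136, 96), since 224/4 = 56 and 56*56 = 3136
```

In the official code this projection is implemented as a strided convolution, which is equivalent to the reshape-then-matmul above.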
SwinTransformerBlock
stage 2
patch merge: concatenate features (dim=4C) of $2\times2$ neighboring patches
$\mathbb{R}^{B\times(h\times w)\times C}\rightarrow\mathbb{R}^{B\times(\frac{h}{2}\times\frac{w}{2})\times4C}\xrightarrow{\text{linear}}\mathbb{R}^{B\times(\frac{h}{2}\times\frac{w}{2})\times2C}$
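A minimal sketch of this merge in NumPy (the function name and the `W_red` reduction weight are hypothetical; the real module uses a learned linear layer):

```python
import numpy as np

def patch_merge(x, h, w, W_red=None):
    """Concatenate features of each 2x2 neighborhood (C -> 4C), then
    linearly reduce 4C -> 2C. x: (B, h*w, C)."""
    B, N, C = x.shape
    if W_red is None:
        W_red = np.random.randn(4 * C, 2 * C) * 0.02  # stand-in weight
    # group tokens into 2x2 neighborhoods and concatenate their features
    x = x.reshape(B, h // 2, 2, w // 2, 2, C).transpose(0, 1, 3, 2, 4, 5)
    x = x.reshape(B, (h // 2) * (w // 2), 4 * C)  # (B, h/2 * w/2, 4C)
    return x @ W_red                              # (B, h/2 * w/2, 2C)

out = patch_merge(np.zeros((1, 56 * 56, 96)), 56, 56)
print(out.shape)  # (1, 784, 192)
```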
SwinTransformerBlock
stages 3 and 4 are similar to stage 2, differing only in feature-map resolution
in the implementation, the block design differs from the figure:
input $\rightarrow$ PatchEmbed (patch partition, linear embed) $\rightarrow$ 3× BasicLayer (SwinTransformerBlock, PatchMerge) $\rightarrow$ output
shifted window (Swin) attention
An illustration of the shifted window approach for computing self-attention in the proposed Swin Transformer architecture. In layer $\ell$ (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window. In the next layer $\ell+1$ (right), the window partitioning is shifted, resulting in new windows. The self-attention computation in the new windows crosses the boundaries of the previous windows in layer $\ell$, providing connections among them.
feature map $x\in\mathbb{R}^{h\times w\times C}$ divided into $ws\times ws$-size windows
self-attention is applied within each window, which contains $ws\times ws$ tokens
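The window split can be sketched as a pure reshape (the function name is hypothetical; the official code has an equivalent `window_partition` helper):

```python
import numpy as np

def window_partition(x, ws):
    """Split (B, h, w, C) into non-overlapping ws x ws windows,
    returning (B * num_windows, ws*ws, C) for batched attention."""
    B, h, w, C = x.shape
    x = x.reshape(B, h // ws, ws, w // ws, ws, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, ws * ws, C)

windows = window_partition(np.zeros((2, 56, 56, 96)), ws=7)
print(windows.shape)  # (128, 49, 96): 2 * (56/7)**2 = 128 windows
```

Because the window axis is folded into the batch axis, a standard batched attention kernel can be applied unchanged to all windows at once.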
WMSA lacks connections across windows; shifted WMSA solves this problem
efficient batch computation for shifted window
problems of shifted window partition
- more windows: $\lceil\frac{h}{ws}\rceil\times\lceil\frac{w}{ws}\rceil\rightarrow(\lceil\frac{h}{ws}\rceil+1)\times(\lceil\frac{w}{ws}\rceil+1)$
- some windows are smaller than $ws\times ws$
Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.
solution: cyclically shift patches toward the top-left
step 1: shift $x$ toward the top-left, with shift size $=\frac{1}{2}ws$
step 2: window partition on shifted $x$
step 3: masked window attention on the windows of $x$, with a precomputed top-left-shifted mask
step 4: reverse cyclic shift on attended results (bottom-right)
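Steps 1 and 4 can be sketched with `np.roll` (the function names are hypothetical; the official code uses `torch.roll` the same way):

```python
import numpy as np

def cyclic_shift(x, shift):
    """Roll the (B, h, w, C) feature map toward the top-left:
    a negative roll moves content up and to the left."""
    return np.roll(x, shift=(-shift, -shift), axis=(1, 2))

def reverse_cyclic_shift(x, shift):
    """Undo the shift by rolling toward the bottom-right."""
    return np.roll(x, shift=(shift, shift), axis=(1, 2))

x = np.arange(16).reshape(1, 4, 4, 1)
shifted = cyclic_shift(x, shift=2)            # e.g. ws=4 -> shift = ws // 2
restored = reverse_cyclic_shift(shifted, 2)
print(np.array_equal(restored, x))  # True
```

The roll keeps the number of windows unchanged; the attention mask in step 3 is what prevents tokens that were wrapped around from attending to spatially distant tokens.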
computational complexity
in ViT, given input feature map $x\in\mathbb{R}^{h\times w\times C}$, the FLOPs of MSA are
$\Omega(\mathrm{MSA})=4hwC^2+2(hw)^2C$
for (Shifted-)WMSA, replace $h\times w$ with window size $ws\times ws$ and multiply the batch size by $n_{windows}=\frac{hw}{(ws)^2}$
$\Omega(\mathrm{(Shifted\text{-})WMSA})=4hwC^2+2hw(ws)^2C$
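To make the linear-vs-quadratic gap concrete, the two formulas can be evaluated at stage-1 resolution (a sketch; the constants simply follow the FLOPs expressions above):

```python
def msa_flops(h, w, C):
    """Global MSA: 4*h*w*C^2 for the qkv/output projections,
    plus 2*(h*w)^2*C for attention scores and the weighted sum."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def wmsa_flops(h, w, C, ws):
    """Windowed MSA: same projections, but attention only spans
    ws*ws tokens in each of the h*w/(ws^2) windows."""
    n_windows = (h * w) // (ws * ws)
    return 4 * h * w * C**2 + n_windows * 2 * (ws * ws) ** 2 * C

# Stage-1 feature map of a 224x224 input: 56x56 tokens, C=96, ws=7
print(msa_flops(56, 56, 96) / 1e9)      # ~2.0 GFLOPs
print(wmsa_flops(56, 56, 96, 7) / 1e9)  # ~0.145 GFLOPs
```

Doubling $h$ and $w$ quadruples `wmsa_flops` (linear in $hw$) but multiplies the attention term of `msa_flops` by 16 (quadratic in $hw$).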
relative positional encoding (RPE)
include a relative position bias $B\in\mathbb{R}^{ws^2\times ws^2}$ for each head when computing similarity
$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d}}+B)V$
where $B$ is drawn from a smaller learned table $\hat{B}\in\mathbb{R}^{(2ws-1)\times(2ws-1)}$, indexed by the relative coordinates of each token pair
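Since relative offsets along each axis lie in $[-(ws-1), ws-1]$, the bias table only needs $(2ws-1)^2$ entries. A sketch of building the index into that table (the function name is hypothetical):

```python
import numpy as np

def relative_position_index(ws):
    """For every pair of positions in a ws x ws window, compute the
    index into the (2*ws-1)**2 learned bias table."""
    coords = np.stack(np.meshgrid(np.arange(ws), np.arange(ws),
                                  indexing="ij")).reshape(2, -1)  # (2, ws*ws)
    rel = coords[:, :, None] - coords[:, None, :]   # (2, ws*ws, ws*ws)
    rel = rel.transpose(1, 2, 0) + (ws - 1)         # shift offsets to start at 0
    return rel[:, :, 0] * (2 * ws - 1) + rel[:, :, 1]

idx = relative_position_index(7)
print(idx.shape, idx.max())  # (49, 49) 168 -> the table has 169 = 13*13 entries
```

Gathering the table at `idx` yields the $ws^2\times ws^2$ matrix $B$ added to the attention logits of each head.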
Swin transformer encoder
with consecutive encoder blocks alternating between WMSA and Shifted-WMSA, the transformer encoder is computed as

$$\begin{aligned} \widehat{z}_{\ell}&=\mathrm{WMSA}(\mathrm{LN}(z_{\ell-1}))+z_{\ell-1} \\ z_{\ell}&=\mathrm{FFN}(\mathrm{LN}(\widehat{z}_{\ell}))+\widehat{z}_{\ell} \\ \widehat{z}_{\ell+1}&=\mathrm{Shifted\text{-}WMSA}(\mathrm{LN}(z_{\ell}))+z_{\ell} \\ z_{\ell+1}&=\mathrm{FFN}(\mathrm{LN}(\widehat{z}_{\ell+1}))+\widehat{z}_{\ell+1} \end{aligned}$$
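The pre-norm residual alternation can be sketched with hypothetical stand-in callables for the real attention and FFN modules:

```python
import numpy as np

def layer_norm(x):
    """Minimal LN over the channel dimension (no learned affine)."""
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)

def swin_block_pair(x, wmsa, shifted_wmsa, ffn):
    """Two successive blocks: pre-norm residual WMSA + FFN, then the
    same with shifted windows. wmsa/shifted_wmsa/ffn are callables
    (hypothetical stand-ins for the real modules)."""
    x = wmsa(layer_norm(x)) + x          # z_hat_l
    x = ffn(layer_norm(x)) + x           # z_l
    x = shifted_wmsa(layer_norm(x)) + x  # z_hat_{l+1}
    x = ffn(layer_norm(x)) + x           # z_{l+1}
    return x

identity = lambda t: t
out = swin_block_pair(np.ones((1, 49, 96)), identity, identity, identity)
print(out.shape)  # (1, 49, 96)
```

This pairing is why the stage depths in the variants below are all even numbers.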
architecture variants
window size: $M=7$
query dimension of each head: $d=32$
expansion layer of each MLP: $\alpha=4$
architecture hyper-parameters of model variants
- Swin-T: C = 96, layer numbers = {2, 2, 6, 2}
- Swin-S: C = 96, layer numbers = {2, 2, 18, 2}
- Swin-B: C = 128, layer numbers = {2, 2, 18, 2}
where $C$ is the channel number of hidden layers in the first stage
Detailed architecture specifications.
Experiment
image classification
dataset ImageNet-1K
Comparison of different backbones on ImageNet-1K classification.
object detection and instance segmentation
framework Cascade Mask R-CNN, ATSS, RepPoints v2, Sparse R-CNN
dataset COCO 2017
Results on COCO object detection and instance segmentation. “$\dag$” denotes that additional deconvolution layers are used to produce hierarchical feature maps. “$\ast$” indicates multi-scale testing.
semantic segmentation
framework
dataset ADE20K
Results of semantic segmentation on the ADE20K val and test set. “$\dag$” indicates additional deconvolution layers are used to produce hierarchical feature maps. “$\ddag$” indicates that the model is pre-trained on ImageNet-22K.