Contribution
- propose Cross-Shaped Window (CSWin) self-attention
- propose Locally-enhanced Positional Encoding (LePE)
Method
model architecture
Left: the overall hierarchical architecture of our proposed CSWin Transformer, Right: the illustration of our proposed CSWin Transformer block.
cross-shaped window (CSWin) attention
local self-attention restricts the attention area of each token within a Transformer block, so more blocks must be stacked to achieve a global receptive field
existing solutions apply halo (HaloNet) or shifted windows (Swin) to enlarge the receptive field
CSWin offers a more efficient way: cross-shaped window self-attention, computed in horizontal and vertical stripes in parallel
(a) Left: the illustration of the Cross-Shaped Window (CSWin) with stripe width sw for the query point (red dot). Right: the computing of CSWin self-attention, where multi-heads $\{h_1, \dots, h_K\}$ are first split into two groups, then the two groups of heads perform self-attention in horizontal and vertical stripes respectively, and finally are concatenated together. (b), (c), (d), and (e) are existing self-attention mechanisms.
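To make the cross shape concrete, the plain-Python sketch below lists which positions one query can reach when half of the heads attend within horizontal stripes and half within vertical stripes, both of width sw. It is an illustration only, not code from the CSWin repository; `cross_region` and `render` are hypothetical helper names.

```python
def cross_region(h, w, sw, qi, qj):
    """Set of (row, col) positions visible to the query at (qi, qj).

    Heads in the horizontal group see the query's horizontal stripe (sw rows x
    all w columns); heads in the vertical group see its vertical stripe (all h
    rows x sw columns). The union of the two is the cross shape.
    """
    r0 = (qi // sw) * sw  # first row of the query's horizontal stripe
    c0 = (qj // sw) * sw  # first column of the query's vertical stripe
    horizontal = {(r, c) for r in range(r0, r0 + sw) for c in range(w)}
    vertical = {(r, c) for r in range(h) for c in range(c0, c0 + sw)}
    return horizontal | vertical


def render(h, w, sw, qi, qj):
    """ASCII picture: Q = query, # = reachable token, . = not attended."""
    region = cross_region(h, w, sw, qi, qj)
    lines = []
    for r in range(h):
        row = ""
        for c in range(w):
            if (r, c) == (qi, qj):
                row += "Q"
            elif (r, c) in region:
                row += "#"
            else:
                row += "."
        lines.append(row)
    return "\n".join(lines)


print(render(8, 8, 2, 3, 5))
# ....##..
# ....##..
# ########
# #####Q##
# ....##..
# ....##..
# ....##..
# ....##..
```

The query sits in the stripe of rows 2 to 3 and the stripe of columns 4 to 5, so it reaches $sw\cdot w + h\cdot sw - sw^2 = 28$ of the 64 tokens in a single layer.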
given input feature $x\in\mathbb{R}^{h\times w\times C}$, x is linearly projected to K heads, which are equally split into 2 parallel groups
each head in the 2 groups performs local self-attention within either horizontal or vertical stripes
for horizontal stripes, x is evenly partitioned into non-overlapping stripes of height sw, each with $sw\times w$ tokens
$$x=[x_1, x_2, \dots, x_m]$$
where $x_i\in\mathbb{R}^{(sw\times w)\times C}$ and $m=\frac{h}{sw}$
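The partition can be sketched with plain NumPy reshapes. This assumes a channels-last $(h, w, C)$ layout and that h and w are divisible by sw; the function names are hypothetical and the vertical case is included for comparison.

```python
import numpy as np


def horizontal_stripes(x, sw):
    """Split x of shape (h, w, C) into m = h // sw stripes of shape (sw * w, C)."""
    h, w, C = x.shape
    assert h % sw == 0, "h must be divisible by sw"
    # Row-major layout: every sw consecutive rows form one contiguous stripe.
    return x.reshape(h // sw, sw * w, C)


def vertical_stripes(x, sw):
    """Split x of shape (h, w, C) into w // sw stripes of shape (h * sw, C)."""
    h, w, C = x.shape
    assert w % sw == 0, "w must be divisible by sw"
    # (h, w//sw, sw, C) -> (w//sw, h, sw, C) -> (w//sw, h*sw, C)
    return x.reshape(h, w // sw, sw, C).transpose(1, 0, 2, 3).reshape(w // sw, h * sw, C)


x = np.arange(4 * 6).reshape(4, 6, 1)  # toy feature map: h=4, w=6, C=1
print(horizontal_stripes(x, 2).shape)  # (2, 12, 1)
print(vertical_stripes(x, 3).shape)    # (2, 12, 1)
```

Treating the stripe index as a leading batch dimension lets one batched attention call process all stripes at once.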
self-attention for the k-th head of the horizontal group is computed within each stripe
$$y_k^i=\text{WMSA}(x_iW_k^Q,\ x_iW_k^K,\ x_iW_k^V),\quad i=1,\dots,m$$
$$\text{HCS-WMSA}_k(x)=[y_k^1, y_k^2, \dots, y_k^m]$$
where $W_k^Q, W_k^K, W_k^V\in\mathbb{R}^{C\times d_k}$ are the query, key and value projection matrices of the k-th head, and $d_k=C/K$
similarly, for vertical stripes (each with $h\times sw$ tokens), the attention is denoted as $\text{VCS-WMSA}_k(x)$
concatenate the horizontal and vertical attention outputs together
$$\text{CS-WMSA}(x)=\text{concat}(head_1, \dots, head_K)\,W$$
where $W\in\mathbb{R}^{C\times C}$ is the projection matrix that projects the self-attention results into the target output dimension, and
$$head_k=\begin{cases}\text{HCS-WMSA}_k(x), & k=1,\dots,\frac{K}{2}\\ \text{VCS-WMSA}_k(x), & k=\frac{K}{2}+1,\dots,K\end{cases}$$
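Putting the pieces together, here is a minimal NumPy sketch of CS-WMSA for a single image with no batch dimension. It assumes h and w are divisible by sw and K is even, uses plain scaled dot-product attention as the per-stripe WMSA, and omits LePE, biases and dropout; all function names are hypothetical.

```python
import numpy as np


def softmax(a):
    a = a - a.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v):
    """Scaled dot-product attention batched over stripes: (n, N, d) -> (n, N, d)."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v


def to_h(t, sw):  # (h, w, d) -> (h // sw, sw * w, d): horizontal stripes
    h, w, d = t.shape
    return t.reshape(h // sw, sw * w, d)


def to_v(t, sw):  # (h, w, d) -> (w // sw, h * sw, d): vertical stripes
    h, w, d = t.shape
    return t.reshape(h, w // sw, sw, d).transpose(1, 0, 2, 3).reshape(w // sw, h * sw, d)


def from_v(t, h, w, sw):  # inverse of to_v
    d = t.shape[-1]
    return t.reshape(w // sw, h, sw, d).transpose(1, 0, 2, 3).reshape(h, w, d)


def cs_wmsa(x, Wq, Wk, Wv, W, K, sw):
    """x: (h, w, C). Wq/Wk/Wv: (C, C), i.e. K per-head (C, C/K) blocks side by side.
    Heads 1..K/2 attend in horizontal stripes, heads K/2+1..K in vertical stripes."""
    h, w, C = x.shape
    dk = C // K
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    heads = []
    for i in range(K):
        sl = slice(i * dk, (i + 1) * dk)  # channels belonging to head i
        qi, ki, vi = q[..., sl], k[..., sl], v[..., sl]
        if i < K // 2:
            out = attention(to_h(qi, sw), to_h(ki, sw), to_h(vi, sw)).reshape(h, w, dk)
        else:
            out = from_v(attention(to_v(qi, sw), to_v(ki, sw), to_v(vi, sw)), h, w, sw)
        heads.append(out)
    return np.concatenate(heads, axis=-1) @ W  # concat(head_1, ..., head_K) W


rng = np.random.default_rng(0)
h, w, C, K, sw = 8, 8, 16, 4, 2
x = rng.standard_normal((h, w, C))
Wq, Wk, Wv, W = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(4))
y = cs_wmsa(x, Wq, Wk, Wv, W, K, sw)
print(y.shape)  # (8, 8, 16)
```

Because each head only mixes tokens inside its stripe, changing a token outside a query's cross leaves that query's output unchanged.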
sw is adjusted per stage: small sw for early stages, larger sw for later stages
for high-resolution inputs, h and w are larger than C in early stages and smaller than C in later stages
- in early stages (large h, w), a small sw reduces computation, giving local attention
- in later stages (small h, w), a large sw enlarges the receptive field, giving global attention
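For the default stripe widths [1, 2, 7, 7], the sketch below checks two things, assuming a 224×224 input with stage resolutions 56, 28, 14 and 7 (a stride-4 stem followed by 2× downsampling per stage): that each sw tiles its feature map exactly, and how much of the map one query reaches through the union of its horizontal and vertical stripes. `cross_area` is a hypothetical helper.

```python
def cross_area(h, w, sw):
    """Distinct tokens one query reaches: its sw x w horizontal stripe plus
    its h x sw vertical stripe, minus the sw x sw overlap."""
    return sw * w + h * sw - sw * sw


stage_sizes = [56, 28, 14, 7]  # assumed feature map sizes for a 224x224 input
stage_sw = [1, 2, 7, 7]        # default stripe widths per stage
for size, sw in zip(stage_sizes, stage_sw):
    assert size % sw == 0  # stripes must tile the feature map exactly
    area = cross_area(size, size, sw)
    print(f"{size}x{size}, sw={sw}: {area}/{size * size} tokens ({area / (size * size):.1%})")
```

Under these assumptions a query reaches 111 of 3136 tokens in stage 1 but all 49 tokens in stage 4, which is the local-to-global progression described above.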
CSWin transformer encoder
with encoder blocks containing CS-WMSA, the transformer encoder is computed as
$$\begin{aligned}\hat{z}_l&=\text{CS-WMSA}(\text{LN}(z_{l-1}))+z_{l-1}\\ z_l&=\text{FFN}(\text{LN}(\hat{z}_l))+\hat{z}_l\end{aligned}$$
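The two residual equations map directly onto a pre-norm block. In this NumPy sketch, the attention is passed in as any callable (for example a CS-WMSA layer), LayerNorm has no learnable scale or shift, and the FFN is assumed to be a two-layer MLP with tanh-approximated GELU; the hidden width is an arbitrary choice and all names are hypothetical.

```python
import numpy as np


def layer_norm(z, eps=1e-5):
    """LayerNorm over the channel axis (learnable scale/shift omitted)."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)


def gelu(a):
    """Tanh approximation of GELU."""
    return 0.5 * a * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (a + 0.044715 * a**3)))


def ffn(z, W1, b1, W2, b2):
    """Two-layer MLP applied to every token independently."""
    return gelu(z @ W1 + b1) @ W2 + b2


def cswin_block(z_prev, attn, ffn_params):
    """Pre-norm residual block.

    z_hat = attn(LN(z_prev)) + z_prev
    z     = FFN(LN(z_hat)) + z_hat
    """
    z_hat = attn(layer_norm(z_prev)) + z_prev
    return ffn(layer_norm(z_hat), *ffn_params) + z_hat


rng = np.random.default_rng(0)
h, w, C, hidden = 4, 4, 8, 32
z = rng.standard_normal((h, w, C))
params = (rng.standard_normal((C, hidden)) * 0.1, np.zeros(hidden),
          rng.standard_normal((hidden, C)) * 0.1, np.zeros(C))
# Placeholder attention: any (h, w, C) -> (h, w, C) callable, e.g. a CS-WMSA layer.
out = cswin_block(z, lambda t: 0.5 * t, params)
print(out.shape)  # (4, 4, 8)
```

Because both sub-layers sit on residual paths, a block whose attention and FFN output zeros reduces to the identity.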
locally-enhanced positional encoding (LePE)
Comparison among different positional encoding mechanisms: APE and CPE introduce the positional information before feeding into the Transformer blocks, while RPE and our LePE operate in each Transformer block. Different from RPE that adds the positional information into the attention calculation, our LePE operates directly upon V and acts as a parallel module. Here we only draw the self-attention part to represent the Transformer block for simplicity.
APE/CPE add positional information before the transformer blocks
RPE adds positional information within the attention calculation
LePE imposes positional information directly upon the linearly projected values V
$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V+EV$$
considering all connections in E would require a huge computation cost, so assume the most vital positional information for an input element comes from its local neighborhood
$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V+\text{DWConv}(V)$$
where LePE is implemented by a depth-wise conv: a 3×3 group conv with groups = embed_dim (one filter per channel)
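A depth-wise 3×3 conv can be read as a sparse E in the $EV$ term: each output token mixes only its 3×3 neighbors, channel by channel. The NumPy sketch below assumes zero padding of 1 and that LePE is applied to V laid out on the spatial grid of one window (for example one stripe); in PyTorch the same layer would be `nn.Conv2d(C, C, kernel_size=3, padding=1, groups=C)`. Function names are hypothetical.

```python
import numpy as np


def depthwise_conv3x3(v, kernels):
    """Depth-wise 3x3 conv (cross-correlation) with zero padding 1.

    v: (h, w, C); kernels: (3, 3, C), one 3x3 filter per channel, like a
    grouped conv with groups = C.
    """
    h, w, C = v.shape
    padded = np.pad(v, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros((h, w, C))
    for di in range(3):
        for dj in range(3):
            out += padded[di:di + h, dj:dj + w, :] * kernels[di, dj, :]
    return out


def lepe_attention(q, k, v, kernels):
    """softmax(QK^T / sqrt(d)) V + DWConv(V) for one window of h x w tokens."""
    h, w, d = v.shape
    qf, kf, vf = (t.reshape(h * w, d) for t in (q, k, v))
    s = qf @ kf.T / np.sqrt(d)
    s = np.exp(s - s.max(axis=1, keepdims=True))
    s /= s.sum(axis=1, keepdims=True)
    return (s @ vf).reshape(h, w, d) + depthwise_conv3x3(v, kernels)


rng = np.random.default_rng(0)
v = rng.standard_normal((2, 7, 4))  # e.g. one horizontal stripe with sw=2, w=7
q, k = rng.standard_normal((2, 2, 7, 4))
out = lepe_attention(q, k, v, rng.standard_normal((3, 3, 4)) * 0.1)
print(out.shape)  # (2, 7, 4)
```

Since the conv acts on V rather than on the attention logits, it runs as a parallel branch and adds only $9C$ multiply-adds per token.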
computational complexity
in ViT, given input feature map $x\in\mathbb{R}^{h\times w\times C}$, the FLOPs of MSA are
$$\Omega(\text{MSA})=4hwC^2+2(hw)^2C$$
for the $\frac12 n_{heads}$ heads on horizontal stripes, each stripe holds $sw\times w$ tokens, so replace h with sw; since these heads carry only half of the channels, both terms are halved
$$\Omega(h)=2w(sw)C^2+(w\times sw)^2C$$
for the $\frac12 n_{heads}$ heads on vertical stripes, each stripe holds $h\times sw$ tokens, so replace w with sw
$$\Omega(w)=2h(sw)C^2+(h\times sw)^2C$$
for CS-WMSA, the stripes are processed as a batch (batchsize × $n_{stripes}$), with $n_{stripes}=\frac{h}{sw}$ horizontal stripes or $\frac{w}{sw}$ vertical stripes
$$\Omega(\text{CS-WMSA})=\Omega(h)\times\frac{h}{sw}+\Omega(w)\times\frac{w}{sw}=4hwC^2+sw(h+w)hwC$$
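The derivation can be checked numerically. The sketch uses the same counting convention as the formulas above (half the heads means half the channels per stripe); the stage size 56×56 with C=64 is only an illustrative assumption.

```python
def flops_msa(h, w, C):
    """Global MSA: projections (4hwC^2) plus QK^T and AV (2(hw)^2 C)."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def flops_half_heads(n_tokens, C):
    """One stripe of n_tokens attended by half of the heads (C/2 channels)."""
    return 2 * n_tokens * C**2 + n_tokens**2 * C


def flops_cs_wmsa(h, w, C, sw):
    horizontal = flops_half_heads(sw * w, C) * (h // sw)  # h/sw stripes of sw x w
    vertical = flops_half_heads(h * sw, C) * (w // sw)    # w/sw stripes of h x sw
    return horizontal + vertical


h = w = 56
C, sw = 64, 1
closed_form = 4 * h * w * C**2 + sw * (h + w) * h * w * C
assert flops_cs_wmsa(h, w, C, sw) == closed_form
print(f"{flops_msa(h, w, C) / flops_cs_wmsa(h, w, C, sw):.1f}x cheaper than global MSA")
```

At this size, global MSA costs roughly 17.7 times more than CS-WMSA; when sw equals h and w, both heads groups see the whole map and the two counts coincide.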
architecture variants
Detailed configurations of different variants of CSWin Transformer. Note that the FLOPs are calculated with 224x224 input.
Experiment
image classification
dataset ImageNet-1K, with the same data augmentation as DeiT
optimizer AdamW: batchsize=1024, 300 epochs, init lr=1e-3, weight decay=0.05 or 0.1, linear warm-up 20 epochs, cosine decay
stochastic depth 0.1, 0.3, 0.5 for CSWin-T, CSWin-S, CSWin-B
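The classification learning-rate schedule can be sketched as a function of the (possibly fractional) epoch. This assumes warm-up starts from 0, decay is applied continuously rather than per step, and the minimum learning rate is 0; these are assumptions, not settings stated in the notes.

```python
import math


def lr_at(epoch, base_lr=1e-3, warmup=20, total=300, min_lr=0.0):
    """Linear warm-up for `warmup` epochs, then cosine decay to min_lr at `total`."""
    if epoch < warmup:
        return base_lr * epoch / warmup
    progress = (epoch - warmup) / (total - warmup)  # 0 at end of warm-up, 1 at the end
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for e in (0, 10, 20, 160, 300):
    print(e, lr_at(e))
```

The peak of 1e-3 is reached at epoch 20, and the rate falls to half of it midway through the cosine phase (epoch 160).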
Comparison of different models on ImageNet-1K classification. “*” means the EfficientNet are trained with other input sizes. Here the models are grouped based on the computation complexity.
- pre-training
dataset ImageNet-21K
optimizer AdamW: batchsize=2048, 90 epochs, init lr=1e-3, weight decay=0.1 or 0.2
- fine-tuning
dataset ImageNet-1K
optimizer AdamW: batchsize=512, 30 epochs, lr=1e-5, weight decay=1e-8
stochastic depth 0.1 for both CSWin-B and CSWin-L
ImageNet-1K fine-tuning results by pre-training on ImageNet-21K datasets.
object detection and instance segmentation
framework Mask R-CNN
dataset COCO
- 1x schedule
optimizer AdamW: batchsize=16, 12 epochs, init lr=1e-4, weight decay=0.05, lr decayed by 0.1 at epochs 8 and 11
- 3x-MS schedule
optimizer AdamW: batchsize=16, 36 epochs, init lr=1e-4, weight decay=0.05, lr decayed by 0.1 at epochs 27 and 33
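Both detection schedules are step decays. A minimal sketch, assuming 0-indexed epochs and that the rate is multiplied by 0.1 once each milestone epoch is reached (a common convention, not confirmed by the notes):

```python
def step_lr(epoch, base_lr=1e-4, milestones=(8, 11), gamma=0.1):
    """0-indexed epoch; lr is multiplied by gamma once for every milestone reached."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)


schedule_1x = [step_lr(e) for e in range(12)]
schedule_3x = [step_lr(e, milestones=(27, 33)) for e in range(36)]
print(schedule_1x[7], schedule_1x[8], schedule_1x[11])
```

The 3x-MS schedule keeps the same shape, only stretched to 36 epochs with later milestones.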
Object detection and instance segmentation performance on the COCO val2017 with the Mask R-CNN framework. The FLOPs (G) are measured at resolution 800x1280, and the models are pre-trained on the ImageNet-1K dataset.
framework Cascade Mask R-CNN
dataset COCO
optimizer AdamW: batchsize=16, 36 epochs, init lr=1e-4, weight decay=0.05, lr decayed by 0.1 at epochs 27 and 33
Object detection and instance segmentation performance on the COCO val2017 with Cascade Mask R-CNN.
semantic segmentation
framework Semantic FPN
dataset ADE20K
optimizer AdamW: batchsize=16, 80K iterations, init lr=1e-4, weight decay=1e-4
framework UPerNet
dataset ADE20K
optimizer AdamW: batchsize=16, 160K iterations, init lr=6e-5, weight decay=5e-4, linear warm-up 1500 iterations, linear decay
stochastic depth 0.1, 0.3, 0.5 for CSWin-T, CSWin-S, CSWin-B
Performance comparison of different backbones on the ADE20K segmentation task. Two different frameworks semantic FPN and Upernet are used. FLOPs are calculated with resolution 512x2048. “+” means the model is pretrained on ImageNet-21K and finetuned with 640x640 resolution.
ablation study
component analysis
Ablation study of each component to better understand CSWin Transformer. “SA”, “Arch”,“CTE” denote “Self-Attention”, “Architecture”, and “Convolutional Token Embedding” respectively.
- evaluate a baseline that fixes sw=1 for the first three stages and observe a dramatic performance drop
indicating that adjusting sw to enlarge the attention area is crucial
- change the parallel self-attention design into a sequential counterpart without multi-head grouping and observe a performance drop
indicating that multi-head grouping is effective
self-attention mechanism
use the shallow-wide design from the above subsection: 2, 2, 6, 2 blocks for the four stages, base channel 96
apply non-overlapped token embedding and RPE in the above models
Ablation study of different self-attention mechanisms and positional encoding mechanisms. “*” denotes applying CPE before every Transformer block.
positional encoding
positional encoding brings a performance gain by introducing a local inductive bias
LePE performs better on downstream tasks where the input resolution varies
stripes width
vary $[sw_1, sw_2, sw_3]$ of the first three stages of CSWin-T and keep the last stage with $sw_4=7$
Ablation study on different stripes width. We show the sw of each stage with the form $[sw_1, sw_2, sw_3, sw_4]$ beside each point and X axis is its corresponding FLOPs.
as sw increases, FLOPs increase and accuracy improves greatly at first, then the gain slows down once $[sw_1, sw_2, sw_3]$ are large enough
the default setting [1, 2, 7, 7] for $[sw_1, sw_2, sw_3, sw_4]$ achieves a better trade-off between accuracy and computation cost