Content
Abstract
- propose Separable Vision Transformer (SepViT) with depth-wise separable self-attention
  - local information interaction within windows
  - global information exchange among windows
- propose window token embedding
  - learn global feature representations of each window
  - model attention relationship among windows with negligible computational cost
- extend depth-wise separable self-attention to grouped self-attention
  - capture more contextual concepts across multiple windows
Comparison of throughput and latency on ImageNet-1K classification. The throughput and the latency are tested based on the PyTorch framework with a V100 GPU and TensorRT framework with a T4 GPU, respectively.
Method
model architecture
Separable Vision Transformer (SepViT). The top row is the overall hierarchical architecture of SepViT. The bottom row is the SepViT block and the detailed visualization of our depth-wise separable self-attention and the window token embedding scheme.
depth-wise separable self-attention (DSSA)
depth-wise self-attention (DWA)
depth-wise convolution: fuse spatial information within each channel
depth-wise attention: fuse spatial information within each window
step 1 partition input features $z^{\ell-1}$ into windows
$$z^{\ell-1}\in\Reals^{B\times C\times H\times W}\xrightarrow{partition}z^{\ell-1}\in\Reals^{(n_{windows}\times B)\times C\times N}$$
where $n_{windows}=\frac{H}{H_{window}}\times\frac{W}{W_{window}}$, $N=H_{window}\times W_{window}$
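As a sanity check on the shapes, here is a minimal PyTorch sketch of the partition step (the sizes and the windows-major layout are illustrative assumptions, not the authors' code):

```python
import torch

B, C, H, W = 2, 96, 56, 56          # batch, channels, feature map size (illustrative)
Hw, Ww = 7, 7                       # window height/width
n_windows = (H // Hw) * (W // Ww)   # 64 non-overlapping windows
N = Hw * Ww                         # 49 tokens per window

z = torch.randn(B, C, H, W)

# B x C x H x W -> (n_windows * B) x C x N
z = z.reshape(B, C, H // Hw, Hw, W // Ww, Ww)
z = z.permute(2, 4, 0, 1, 3, 5)     # nH, nW, B, C, Hw, Ww
z = z.reshape(n_windows * B, C, N)
print(z.shape)                      # torch.Size([128, 96, 49])
```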
step 2 concatenate window token $wt$ and windowed features
$$wt\in\Reals^{(n_{windows}\times B)\times C\times 1},\ z^{\ell-1}\in\Reals^{(n_{windows}\times B)\times C\times N}\xrightarrow{concat}\tilde{z}^{\ell}\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}$$
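A minimal sketch of the concatenation, assuming a fixed zero window token; whether the token is prepended or appended is an implementation choice (prepended here):

```python
import torch

n_windows_B, C, N = 128, 96, 49
z = torch.randn(n_windows_B, C, N)                     # windowed features from step 1

wt = torch.zeros(1, C, 1).expand(n_windows_B, -1, -1)  # one zero token per window
z_tilde = torch.cat([wt, z], dim=-1)                   # (n_windows*B) x C x (N+1)
print(z_tilde.shape)                                   # torch.Size([128, 96, 50])
```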
step 3 project features into query, key, value
$$\tilde{z}^{\ell}\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}\xrightarrow{linear}qkv\in\Reals^{(n_{windows}\times B)\times 3C\times(N+1)}\xrightarrow{split}q,k,v\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}$$
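A sketch of the projection, using a 1x1 Conv1d to stand in for the linear layer so the channel-first layout of the equations is preserved (the actual implementation may instead apply nn.Linear on a channel-last layout):

```python
import torch
import torch.nn as nn

n_windows_B, C, N1 = 128, 96, 50               # N+1 = 50 tokens per window
z_tilde = torch.randn(n_windows_B, C, N1)

to_qkv = nn.Conv1d(C, 3 * C, kernel_size=1)    # acts as a per-token linear layer
qkv = to_qkv(z_tilde)                          # (n_windows*B) x 3C x (N+1)
q, k, v = qkv.chunk(3, dim=1)                  # each (n_windows*B) x C x (N+1)
```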
step 4 split query, key, value into multi-head version
$$q,k,v\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}\xrightarrow{split}q,k,v\in\Reals^{(n_{windows}\times B)\times n_{heads}\times(N+1)\times C_{head}}$$
where $C=n_{heads}\times C_{head}$
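The multi-head split is pure reshaping; a sketch with illustrative sizes:

```python
import torch

n_windows_B, C, N1, n_heads = 128, 96, 50, 3
C_head = C // n_heads                    # C = n_heads * C_head = 3 * 32

q = torch.randn(n_windows_B, C, N1)
# ... x C x (N+1) -> ... x n_heads x (N+1) x C_head
q = q.reshape(n_windows_B, n_heads, C_head, N1).transpose(-1, -2)
print(q.shape)                           # torch.Size([128, 3, 50, 32])
```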
step 5 produce features with depth-wise attention
$$\ddot{z}^{\ell}=\mathrm{Attention}(q,k,v)\in\Reals^{(n_{windows}\times B)\times n_{heads}\times(N+1)\times C_{head}}$$
to sum up, depth-wise attention is formulated as
$$\mathrm{DWA}(z^{\ell-1})=\mathrm{Attention}(z^{\ell-1}\cdot W_Q,\ z^{\ell-1}\cdot W_K,\ z^{\ell-1}\cdot W_V)$$
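The $\mathrm{Attention}(\cdot)$ here is ordinary scaled dot-product attention, applied independently within each window since windows live in the batch dimension; a minimal sketch:

```python
import torch

def attention(q, k, v):
    """Scaled dot-product attention over the last two dims."""
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)).mul(scale).softmax(dim=-1)
    return attn @ v

# q, k, v: (n_windows*B) x n_heads x (N+1) x C_head
q = k = v = torch.randn(128, 3, 50, 32)
z_ddot = attention(q, k, v)              # same shape as the inputs
```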
window token embedding
aim: model the attention relationship among windows
straightforward solution: employ all pixel tokens $\implies$ huge computational cost
new solution: window token embedding $\implies$ negligible computational cost
- a fixed zero vector
- a learnable vector with initialization of zero
in implementation, the window token is a 1D tensor with the same channel dimension as the input features
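A sketch of the learnable variant, assuming the token is a zero-initialized nn.Parameter shared across all windows (the module name is hypothetical):

```python
import torch
import torch.nn as nn

class WindowToken(nn.Module):
    """Learnable window token, zero-initialized, shared across windows."""
    def __init__(self, dim: int):
        super().__init__()
        self.token = nn.Parameter(torch.zeros(1, dim, 1))  # 1D along the token axis

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (n_windows*B) x C x N -> prepend one token per window
        wt = self.token.expand(z.shape[0], -1, -1)
        return torch.cat([wt, z], dim=-1)                  # ... x C x (N+1)
```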
point-wise self-attention (PWA)
point-wise convolution: fuse information from different channels
point-wise attention: fuse information across windows
step 1 split window token and windowed features from the DWA output $\ddot{z}^{\ell}$
$$\ddot{z}^{\ell}\in\Reals^{(n_{windows}\times B)\times n_{heads}\times(N+1)\times C_{head}}\xrightarrow{slice}\dot{wt}\in\Reals^{B\times n_{heads}\times n_{windows}\times C_{head}},\ \dot{z}^{\ell}\in\Reals^{B\times n_{heads}\times n_{windows}\times N\times C_{head}}$$
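A sketch of the slice step, assuming the window token sits at index 0 of the token axis (consistent with the prepend in step 2 of DWA):

```python
import torch

n_windows, B, n_heads, N, C_head = 64, 2, 3, 49, 32
z_ddot = torch.randn(n_windows * B, n_heads, N + 1, C_head)

# regroup the window axis out of the batch: B x n_heads x n_windows x (N+1) x C_head
z_ddot = z_ddot.reshape(n_windows, B, n_heads, N + 1, C_head).permute(1, 2, 0, 3, 4)
wt_dot = z_ddot[:, :, :, 0, :]           # B x n_heads x n_windows x C_head
z_dot = z_ddot[:, :, :, 1:, :]           # B x n_heads x n_windows x N x C_head
```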
step 2 project window token into window query, key
$$\dot{wt}\in\Reals^{B\times n_{heads}\times n_{windows}\times C_{head}}\xrightarrow{norm+act}\xrightarrow{conv}\xrightarrow{split}q_w,k_w\in\Reals^{B\times n_{heads}\times n_{windows}\times C_{head}}$$
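A sketch of this window-token projection, assuming LayerNorm over the head channels and a 1x1 conv that doubles the per-head channels before the split (the exact projection layer is an assumption):

```python
import torch
import torch.nn as nn

B, n_heads, n_windows, C_head = 2, 3, 64, 32
wt_dot = torch.randn(B, n_heads, n_windows, C_head)

norm = nn.LayerNorm(C_head)
act = nn.GELU()
proj = nn.Conv2d(n_heads, 2 * n_heads, kernel_size=1)  # doubles per-head channels

x = proj(act(norm(wt_dot)))              # B x 2*n_heads x n_windows x C_head
q_w, k_w = x.chunk(2, dim=1)             # each B x n_heads x n_windows x C_head
```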
step 3 produce features with point-wise attention
$$\hat{z}^{\ell}=\mathrm{PWA}(\dot{z}^{\ell},\dot{wt})=\mathrm{Attention}(q_w,k_w,\dot{z}^{\ell})\in\Reals^{B\times n_{heads}\times n_{windows}\times N\times C_{head}}$$
where the windowed features $\dot{z}^{\ell}$ from the DWA output are directly used as the window value
to sum up, point-wise attention is formulated as
$$\mathrm{PWA}(\dot{z}^{\ell},\dot{wt})=\mathrm{Attention}(\mathrm{GELU}(\mathrm{LN}(\dot{wt}))\cdot W_Q,\ \mathrm{GELU}(\mathrm{LN}(\dot{wt}))\cdot W_K,\ \dot{z}^{\ell})$$
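Since each window's value is a whole $N\times C_{head}$ feature map, the window-level attention weights mix entire windows; a sketch of one way to realize this contraction (the einsum is my reading of the shapes, not the authors' code):

```python
import torch

B, n_heads, n_windows, N, C_head = 2, 3, 64, 49, 32
q_w = torch.randn(B, n_heads, n_windows, C_head)
k_w = torch.randn(B, n_heads, n_windows, C_head)
z_dot = torch.randn(B, n_heads, n_windows, N, C_head)  # window values from DWA

attn = (q_w @ k_w.transpose(-2, -1)) * C_head ** -0.5  # B x h x n_win x n_win
attn = attn.softmax(dim=-1)

# out[b,h,i] = sum_j attn[b,h,i,j] * z_dot[b,h,j]  (mix whole windows)
z_hat = torch.einsum('bhij,bhjnc->bhinc', attn, z_dot)
print(z_hat.shape)                        # torch.Size([2, 3, 64, 49, 32])
```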
grouped self-attention (GSA)
A macro view of the similarities and differences between the depth-wise separable self-attention and the grouped self-attention.
transformer encoder
to sum up, each SepViT block is formulated as
$$\begin{aligned} \tilde{z}^{\ell}&=\mathrm{Concat}(z^{\ell-1}, wt) \\ \ddot{z}^{\ell}&=\mathrm{DWA}(\mathrm{LN}(\tilde{z}^{\ell})) \\ \dot{z}^{\ell}, \dot{wt}&=\mathrm{Slice}(\ddot{z}^{\ell}) \\ \hat{z}^{\ell}&=\mathrm{PWA}(\dot{z}^{\ell}, \dot{wt})+z^{\ell-1} \\ z^{\ell}&=\mathrm{MLP}(\mathrm{LN}(\hat{z}^{\ell}))+\hat{z}^{\ell} \end{aligned}$$
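A high-level sketch wiring these five equations together; `dwa`, `slice_fn`, `pwa`, and `merge` are assumed callables implementing the earlier step-by-step sketches, and tokens are kept channel-last here so the LayerNorms apply to the channel dim:

```python
import torch

def sepvit_block(z, wt, dwa, slice_fn, pwa, merge, norm1, norm2, mlp):
    """One SepViT block following the equations above (a sketch, not the
    reference implementation). z: (n_windows*B) x N x C windowed features,
    wt: (n_windows*B) x 1 x C window token."""
    z_tilde = torch.cat([wt, z], dim=1)          # prepend window token
    z_ddot = dwa(norm1(z_tilde))                 # depth-wise attention (DWA)
    z_dot, wt_dot = slice_fn(z_ddot)             # split features / window token
    z_hat = merge(pwa(z_dot, wt_dot)) + z        # point-wise attention (PWA) + residual
    return mlp(norm2(z_hat)) + z_hat             # single MLP + residual
```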
Complexity comparison of information interaction within and among windows: a single SepViT block vs. the two-block pattern used by other works in each stage.
reasons for the low computational cost
- more lightweight
- removes many redundant layers: 1 MLP + 2 LN in a single SepViT block vs. 2 MLP + 4 LN in two successive Swin or Twins blocks
architecture variants
Detailed configurations of SepViT variants in different stages.
Experiment
image classification
dataset ImageNet-1K
optimizer AdamW: batch size=1024, 300 epochs, init lr=1e-3, weight decay=0.05 for SepViT-T/S or 0.1 for SepViT-B, linear warm-up 5 epochs for SepViT-T/S or 20 epochs for SepViT-B, cosine decay
stochastic depth 0.2, 0.3, 0.5 for SepViT-T, SepViT-S, SepViT-B
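A hedged sketch of this recipe for SepViT-T/S with stock PyTorch schedulers (the model is a stand-in and per-step scheduler granularity is an assumption):

```python
import torch

model = torch.nn.Linear(8, 8)            # stand-in; use a SepViT model in practice
epochs, warmup_epochs = 300, 5
steps_per_epoch = 1281167 // 1024        # ImageNet-1K images / batch size

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-3, total_iters=warmup_epochs * steps_per_epoch)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=(epochs - warmup_epochs) * steps_per_epoch)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine],
    milestones=[warmup_epochs * steps_per_epoch])
```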
Comparison of different state-of-the-art methods on ImageNet-1K classification. Throughput and latency are tested based on the PyTorch framework with a V100 GPU (batch size=192) and TensorRT framework with a T4 GPU (batch size=8).
object detection and instance segmentation
framework RetinaNet, Mask R-CNN
dataset COCO 2017
- 1x schedule
  optimizer AdamW: batch size=16, 12 epochs, init lr=1e-4, weight decay=1e-3 for SepViT-T or 1e-4 for SepViT-S, warm-up 500 iterations, decay rate=0.1 at epochs 8 and 11
  stochastic depth 0.2, 0.3 for SepViT-T, SepViT-S
- 3x-MS schedule
  optimizer AdamW: batch size=16, 36 epochs, init lr=1e-4, weight decay=0.05 for SepViT-T or 0.1 for SepViT-S, warm-up 500 iterations, decay rate=0.1 at epochs 27 and 33
  stochastic depth 0.3 for SepViT-T/S
Comparison of different backbones on RetinaNet-based object detection task. FLOPs are measured with the input size of $800\times1280$.
Comparison of different backbones on Mask R-CNN-based object detection and instance segmentation tasks. FLOPs are measured with the input size of $800\times1280$. The superscripts $b$ and $m$ denote box detection and mask instance segmentation.
semantic segmentation
framework Semantic FPN
dataset ADE20K, ImageNet (pre-training)
optimizer AdamW: batch size=16, 80K iterations, init lr=1e-4, weight decay=1e-4, polynomial lr decay (power 0.9)
stochastic depth 0.2, 0.3, 0.4 for SepViT-T, SepViT-S, SepViT-B
framework UperNet
dataset ADE20K, ImageNet (pre-training)
optimizer AdamW: batch size=16, 160K iterations, init lr=6e-5, weight decay=0.01 for SepViT-T/S or 0.03 for SepViT-B
stochastic depth 0.2, 0.3, 0.5 for SepViT-T, SepViT-S, SepViT-B
Comparison of different backbones on ADE20K semantic segmentation task. FLOPs are measured with the input size of $512\times2048$.
ablation studies
efficient components
SepViT adopts conditional position encoding (CPE) and overlapping patch embedding (OPE)
Swin-T+CPVT: taken as baseline
SepViT-T$^\dag$: with CPE but without OPE
Ablation studies of the key components in our SepViT. LWT means initializing the window tokens with learnable vectors.
window token embedding
window token initialized as
- a fixed zero vector
- a learnable vector with initialization of zero
schemes for learning global representation
- Win_Tokens: window token embedding
- Avg_Pool: average pooling
- Dw_Conv: depth-wise convolution
parameters and FLOPs comparison between Win_Tokens and Avg_Pool methods $\implies$ window token embedding brings negligible computational cost
Comparison of different approaches of getting the global representation of each window in SepViT.
comparison with lite models
Comparison of lite models on ImageNet-1K classification.