[2203] SepViT: Separable Vision Transformer

paper

Abstract

  • propose Separable Vision Transformer (SepViT) with depth-wise separable self-attention
    local information interaction within windows
    global information exchange among windows
  • propose window token embedding
    learn global feature representations of each window
    model attention relationship among windows with negligible computational cost
  • extend depth-wise separable self-attention to grouped self-attention
    capture more contextual concepts across multiple windows

2203_sepvit_f1
Comparison of throughput and latency on ImageNet-1K classification. The throughput and the latency are tested based on the PyTorch framework with a V100 GPU and TensorRT framework with a T4 GPU, respectively.

Method

model architecture

2203_sepvit_f2
Separable Vision Transformer (SepViT). The top row is the overall hierarchical architecture of SepViT. The bottom row is the SepViT block and the detailed visualization of our depth-wise separable self-attention and the window token embedding scheme.

depth-wise separable self-attention (DSSA)
depth-wise self-attention (DWA)

depth-wise convolution: fuse spatial information within each channel
depth-wise attention: fuse spatial information within each window

step 1 partition input features $z^{\ell-1}$ into windows

$$z^{\ell-1}\in\Reals^{B\times C\times H\times W}\xrightarrow{partition}z^{\ell-1}\in\Reals^{(n_{windows}\times B)\times C\times N}$$

where $n_{windows}=\frac{H}{H_{window}}\times\frac{W}{W_{window}}$ and $N=H_{window}\times W_{window}$
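A minimal PyTorch sketch of this partition step under the shape conventions above (`window_partition` is an assumed helper name, not the authors' code):

```python
import torch

def window_partition(x, Hw, Ww):
    """(B, C, H, W) -> (n_windows * B, C, N) with N = Hw * Ww.

    Assumes H and W are divisible by the window size Hw x Ww.
    """
    B, C, H, W = x.shape
    x = x.view(B, C, H // Hw, Hw, W // Ww, Ww)
    # bring the window-grid dims to the front: (H//Hw, W//Ww, B, C, Hw, Ww)
    x = x.permute(2, 4, 0, 1, 3, 5)
    # flatten the grid into the batch dim and the window pixels into N
    return x.reshape(-1, C, Hw * Ww)

# e.g. window_partition(torch.randn(2, 96, 56, 56), 7, 7).shape == (128, 96, 49)
```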
step 2 concatenate window token $wt$ with the windowed features

$$wt\in\Reals^{(n_{windows}\times B)\times C\times1},\ z^{\ell-1}\in\Reals^{(n_{windows}\times B)\times C\times N}\xrightarrow{concat}\tilde{z}^{\ell}\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}$$

step 3 project features into query, key, value

$$\tilde{z}^{\ell}\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}\xrightarrow{linear}qkv\in\Reals^{(n_{windows}\times B)\times3C\times(N+1)}\xrightarrow{split}q, k, v\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}$$

step 4 split query, key, value into the multi-head version

$$q, k, v\in\Reals^{(n_{windows}\times B)\times C\times(N+1)}\xrightarrow{split}q, k, v\in\Reals^{(n_{windows}\times B)\times n_{heads}\times(N+1)\times C_{head}}$$

where $C=n_{heads}\times C_{head}$
step 5 produce features with depth-wise attention

$$\ddot{z}^{\ell}=\mathrm{Attention}(q, k, v)\in\Reals^{(n_{windows}\times B)\times n_{heads}\times(N+1)\times C_{head}}$$

to sum up, depth-wise attention is formulated as

$$\mathrm{DWA}(z^{\ell-1})=\mathrm{Attention}(z^{\ell-1}\cdot W_Q,\ z^{\ell-1}\cdot W_K,\ z^{\ell-1}\cdot W_V)$$
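Steps 3-5 amount to standard multi-head self-attention within each window once the window token is concatenated. A hedged PyTorch sketch (module and attribute names such as `DWA.qkv` are assumptions, not the released implementation):

```python
import torch
import torch.nn as nn

class DWA(nn.Module):
    """Depth-wise self-attention over one window plus its window token.

    Input:  (n_windows * B, C, N + 1) -- windowed features with the window
            token already concatenated (step 2).
    Output: (n_windows * B, n_heads, N + 1, C_head).
    """
    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)                 # step 3

    def forward(self, z):
        Bw, C, N1 = z.shape                                # N1 = N + 1
        z = z.transpose(1, 2)                              # (Bw, N+1, C)
        qkv = self.qkv(z).reshape(Bw, N1, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)               # step 4
        attn = (q @ k.transpose(-2, -1)) * self.scale      # step 5
        return attn.softmax(dim=-1) @ v
```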

window token embedding

aim: model the attention relationship among windows
straightforward solution: employ all pixel tokens $\implies$ huge computational cost
new solution: window token embedding $\implies$ negligible computational cost

  • a fixed zero vector
  • a learnable vector with initialization of zero

in implementation, the window token is a 1D tensor with the same channel dimension as the input features
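A minimal sketch of the learnable variant (a zero-initialized parameter broadcast to every window and concatenated as in step 2 of DWA; names are assumptions):

```python
import torch
import torch.nn as nn

class WindowTokenEmbedding(nn.Module):
    """One learnable token per window, initialized to zero."""
    def __init__(self, dim):
        super().__init__()
        self.wt = nn.Parameter(torch.zeros(1, dim, 1))      # (1, C, 1)

    def forward(self, windows):
        # windows: (n_windows * B, C, N) -> (n_windows * B, C, N + 1)
        wt = self.wt.expand(windows.shape[0], -1, -1)
        return torch.cat([windows, wt], dim=-1)
```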

point-wise self-attention (PWA)

point-wise convolution: fuse information from different channels
point-wise attention: fuse information across windows

step 1 split the window token and windowed features from the DWA output $\ddot{z}^{\ell}$

$$\ddot{z}^{\ell}\in\Reals^{(n_{windows}\times B)\times n_{heads}\times(N+1)\times C_{head}}\xrightarrow{slice}\dot{wt}\in\Reals^{B\times n_{heads}\times n_{windows}\times C_{head}},\ \dot{z}^{\ell}\in\Reals^{B\times n_{heads}\times n_{windows}\times N\times C_{head}}$$

step 2 project the window token into window-wise query and key

$$\dot{wt}\in\Reals^{B\times n_{heads}\times n_{windows}\times C_{head}}\xrightarrow{norm+act}\xrightarrow{conv}\xrightarrow{split}q_w, k_w\in\Reals^{B\times n_{heads}\times n_{windows}\times C_{head}}$$

step 3 produce features with point-wise attention

$$\hat{z}^{\ell}=\mathrm{PWA}(\dot{z}^{\ell}, \dot{wt})=\mathrm{Attention}(q_w, k_w, \dot{z}^{\ell})\in\Reals^{B\times n_{heads}\times n_{windows}\times N\times C_{head}}$$

where the windowed features $\dot{z}^{\ell}$ from the DWA output are directly used as the window value
to sum up, point-wise attention is formulated as

$$\mathrm{PWA}(\dot{z}^{\ell}, \dot{wt})=\mathrm{Attention}(\mathrm{GELU}(\mathrm{LN}(\dot{wt}))\cdot W_Q,\ \mathrm{GELU}(\mathrm{LN}(\dot{wt}))\cdot W_K,\ \dot{z}^{\ell})$$
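A hedged sketch of PWA under the shapes above. The LN + GELU + projection of the window tokens follows the formula (a `Linear` stands in for the 1x1 conv), and using each whole window as the value is expressed by flattening $N$ and $C_{head}$ before the matmul. Names are assumptions:

```python
import torch
import torch.nn as nn

class PWA(nn.Module):
    """Point-wise self-attention: window tokens attend across windows,
    and the windowed DWA features serve directly as the value."""
    def __init__(self, head_dim):
        super().__init__()
        self.norm = nn.LayerNorm(head_dim)
        self.act = nn.GELU()
        self.proj = nn.Linear(head_dim, head_dim * 2)        # -> q_w, k_w
        self.scale = head_dim ** -0.5

    def forward(self, z_dot, wt_dot):
        # z_dot:  (B, n_heads, n_windows, N, C_head) -- window value
        # wt_dot: (B, n_heads, n_windows, C_head)    -- window tokens
        qk = self.proj(self.act(self.norm(wt_dot)))
        q_w, k_w = qk.chunk(2, dim=-1)
        attn = (q_w @ k_w.transpose(-2, -1)) * self.scale    # (B, heads, n_win, n_win)
        attn = attn.softmax(dim=-1)
        # each output window is a weighted mix of whole input windows
        B, nh, nw, N, Ch = z_dot.shape
        out = attn @ z_dot.reshape(B, nh, nw, N * Ch)
        return out.reshape(B, nh, nw, N, Ch)
```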

grouped self-attention (GSA)

similar to grouped convolution, windows are grouped together and self-attention is performed within each group of windows, capturing more contextual concepts across multiple windows; GSA runs alternately with DSSA in the later stages of the network

2203_sepvit_f3
A macro view of the similarities and differences between the depth-wise separable self-attention and the grouped self-attention.

transformer encoder

to sum up, each SepViT block is formulated as

$$\begin{aligned} \tilde{z}^{\ell}&=\mathrm{Concat}(z^{\ell-1}, wt) \\ \ddot{z}^{\ell}&=\mathrm{DWA}(\mathrm{LN}(\tilde{z}^{\ell})) \\ \dot{z}^{\ell}, \dot{wt}&=\mathrm{Slice}(\ddot{z}^{\ell}) \\ \hat{z}^{\ell}&=\mathrm{PWA}(\dot{z}^{\ell}, \dot{wt})+z^{\ell-1} \\ z^{\ell}&=\mathrm{MLP}(\mathrm{LN}(\hat{z}^{\ell}))+\hat{z}^{\ell} \end{aligned}$$
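Composing the sketches above, a schematic forward pass of one SepViT block that mirrors these five formulas. It reuses the `window_partition`, `WindowTokenEmbedding`, `DWA`, and `PWA` sketches from earlier sections; `window_reverse` is a hypothetical inverse helper defined here for completeness, and all wiring is an assumption, not the released code:

```python
import torch
import torch.nn as nn

def window_reverse(z, Hw, Ww, H, W):
    """Inverse of window_partition: (n_windows * B, C, N) -> (B, C, H, W)."""
    n_h, n_w = H // Hw, W // Ww
    B = z.shape[0] // (n_h * n_w)
    z = z.view(n_h, n_w, B, -1, Hw, Ww).permute(2, 3, 0, 4, 1, 5)
    return z.reshape(B, -1, H, W)

def sepvit_block(x, wt_embed, ln1, dwa, pwa, ln2, mlp, Hw, Ww):
    """One SepViT block; wt_embed/dwa/pwa are the earlier sketches,
    ln1/ln2 are LayerNorms over C, mlp is a channel MLP."""
    B, C, H, W = x.shape
    n_win, N = (H // Hw) * (W // Ww), Hw * Ww
    shortcut = x

    z = window_partition(x, Hw, Ww)                       # (n_win*B, C, N)
    z = wt_embed(z)                                       # Concat(z, wt)
    z = ln1(z.transpose(1, 2)).transpose(1, 2)            # LN over channels
    z = dwa(z)                                            # (n_win*B, heads, N+1, C_head)

    # Slice: separate window tokens from windowed features, regroup by batch
    heads, C_head = z.shape[1], z.shape[3]
    z = z.view(n_win, B, heads, N + 1, C_head).permute(1, 2, 0, 3, 4)
    z_dot, wt_dot = z[..., :N, :], z[..., N, :]
    z = pwa(z_dot, wt_dot)                                # (B, heads, n_win, N, C_head)

    # Merge heads, undo the partition, apply the two residual branches
    z = z.permute(2, 0, 1, 4, 3).reshape(n_win * B, heads * C_head, N)
    z = window_reverse(z, Hw, Ww, H, W) + shortcut
    y = ln2(z.flatten(2).transpose(1, 2))                 # (B, H*W, C)
    return z + mlp(y).transpose(1, 2).view(B, C, H, W)
```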

2203_sepvit_f4
Complexity comparison of an information interaction within and among windows in a single SepViT block with those two-block pattern works in each stage.

reasons for the low computational cost

  • more lightweight design
  • removes many redundant layers
    only 1 MLP + 2 LN in a single SepViT block vs. 2 MLP + 4 LN in two successive Swin or Twins blocks

architecture variants

2203_sepvit_t1
Detailed configurations of SepViT variants in different stages.

Experiment

image classification

dataset ImageNet-1K
optimizer AdamW: batch size=1024, 300 epochs, init lr=1e-3, weight decay=0.05 for SepViT-T/S or 0.1 for SepViT-B, linear warm-up 5 epochs for SepViT-T/S or 20 epochs for SepViT-B, cosine decay
stochastic depth 0.2, 0.3, 0.5 for SepViT-T, SepViT-S, SepViT-B
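A hedged sketch of this recipe with standard PyTorch schedulers (hyperparameters for SepViT-T; `model` is a placeholder, and this approximates the schedule rather than reproducing the authors' training script):

```python
import torch
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

model = torch.nn.Linear(8, 8)   # stand-in for SepViT-T

epochs, warmup_epochs = 300, 5  # 20 warm-up epochs for SepViT-B
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-3, total_iters=warmup_epochs),
        CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs),
    ],
    milestones=[warmup_epochs],
)

for epoch in range(epochs):
    # ... one training epoch at batch size 1024, stochastic depth 0.2 ...
    scheduler.step()
```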

2203_sepvit_t2
Comparison of different state-of-the-art methods on ImageNet-1K classification. Throughput and latency are tested based on the PyTorch framework with a V100 GPU (batch size=192) and the TensorRT framework with a T4 GPU (batch size=8).

object detection and instance segmentation

framework RetinaNet, Mask R-CNN
dataset COCO 2017

  • 1x schedule (a scheduler sketch follows this list)
    optimizer AdamW: batch size=16, 12 epochs, init lr=1e-4, weight decay=1e-3 for SepViT-T or 1e-4 for SepViT-S, warm-up 500 iterations, lr decayed by 0.1 at epochs 8 and 11
    stochastic depth 0.2, 0.3 for SepViT-T, SepViT-S
  • 3x-MS schedule
    optimizer AdamW: batch size=16, 36 epochs, init lr=1e-4, weight decay=0.05 for SepViT-T or 0.1 for SepViT-S, warm-up 500 iterations, lr decayed by 0.1 at epochs 27 and 33
    stochastic depth 0.3 for SepViT-T/S
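For the 1x schedule, a hedged sketch of the step decay with `MultiStepLR` (`model` is a placeholder; the 3x-MS schedule only changes the epoch count and milestones):

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(8, 8)   # stand-in for the RetinaNet/Mask R-CNN model

# 1x: 12 epochs, lr decayed by 0.1 at epochs 8 and 11 (SepViT-T hyperparameters)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[8, 11], gamma=0.1)

for epoch in range(12):
    # ... one training epoch at batch size 16 ...
    scheduler.step()
```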

2203_sepvit_t4
Comparison of different backbones on the RetinaNet-based object detection task. FLOPs are measured with an input size of $800\times1280$.

2203_sepvit_t5
Comparison of different backbones on the Mask R-CNN-based object detection and instance segmentation tasks. FLOPs are measured with an input size of $800\times1280$. The superscripts $b$ and $m$ denote box detection and mask instance segmentation.

semantic segmentation

framework Semantic FPN
dataset ADE20K, ImageNet (pre-training)
optimizer AdamW: batch size=16, 80K iterations, init lr=1e-4, weight decay=1e-4, polynomial lr decay (power 0.9)
stochastic depth 0.2, 0.3, 0.4 for SepViT-T, SepViT-S, SepViT-B
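A hedged sketch of the Semantic FPN recipe's polynomial decay, stepped per iteration (`model` is a placeholder; `PolynomialLR` requires a recent PyTorch):

```python
import torch
from torch.optim.lr_scheduler import PolynomialLR

model = torch.nn.Linear(8, 8)   # stand-in for the Semantic FPN model

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# 80K iterations, polynomial decay with power 0.9
scheduler = PolynomialLR(optimizer, total_iters=80_000, power=0.9)

for it in range(80_000):
    # ... one training iteration at batch size 16 ...
    scheduler.step()
```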

framework UperNet
dataset ADE20K, ImageNet (pre-training)
optimizer AdamW: batch size=16, 160K iterations, init lr=6e-5, weight decay=0.01 for SepViT-T/S or 0.03 for SepViT-B
stochastic depth 0.2, 0.3, 0.5 for SepViT-T, SepViT-S, SepViT-B

2203_sepvit_t3
Comparison of different backbones on the ADE20K semantic segmentation task. FLOPs are measured with an input size of $512\times2048$.

ablation studies
efficient components

SepViT adopts conditional positional encoding (CPE) and overlapping patch embedding (OPE)
Swin-T+CPVT: taken as the baseline
SepViT-T$^\dag$: with CPE but without OPE

2203_sepvit_t6
Ablation studies of the key components in our SepViT. LWT means initializing the window tokens with learnable vectors.

window token embedding

window token initialized as

  • a fixed zero vector
  • a learnable vector with initialization of zero

schemes for learning global representation

  • Win_Tokens: window token embedding
  • Avg_Pool: average pooling
  • Dw_Conv: depth-wise convolution

parameters and FLOPs comparison between the Win_Tokens and Avg_Pool methods
$\implies$ window token embedding brings negligible computational cost

2203_sepvit_t7
Comparison of different approaches of getting the global representation of each window in SepViT.

comparison with lite models

2203_sepvit_t8
Comparison of lite models on ImageNet-1K classification.
