[2106] [NIPS 2021] Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer

paper
code

Contribution

  • propose Shuffle self-attention with spatial shuffle for cross-window connection

Method

model architecture


The architecture of a Shuffle Transformer (Shuffle-T).

shuffle window self-attention
window-based self-attention

feature map $x\in R^{h\times w\times C}$ is divided into non-overlapping $ws\times ws$ windows
self-attention is applied within each sub-window, which contains $ws\times ws$ tokens
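
The window partitioning above can be sketched with plain NumPy reshapes (a minimal sketch; `window_partition` and the exact tensor layout are illustrative, not the paper's code):

```python
import numpy as np

def window_partition(x, ws):
    # Split an (h, w, C) feature map into non-overlapping ws x ws windows.
    h, w, C = x.shape
    x = x.reshape(h // ws, ws, w // ws, ws, C)
    x = x.transpose(0, 2, 1, 3, 4)        # (h/ws, w/ws, ws, ws, C)
    return x.reshape(-1, ws * ws, C)      # (n_windows, ws*ws tokens, C)

x = np.arange(8 * 8 * 3).reshape(8, 8, 3).astype(np.float32)
windows = window_partition(x, ws=4)
print(windows.shape)  # (4, 16, 3): four 4x4 windows of 16 tokens each
```

Self-attention is then run independently on each of the `n_windows` groups of `ws*ws` tokens.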

spatial shuffle for cross-window connection

problem of window-based self-attention: the receptive field is limited to the window, especially on high-resolution input
solution: introduce spatial shuffle, inspired by the channel shuffle in ShuffleNet


Spatial shuffle with two stacked window-based Transformer blocks. The MLP is omitted in the visualization because it does not affect information interaction in the spatial dimension. WMSA stands for window-based multi-head self-attention. a) two stacked window-based Transformer blocks with the same window size: each output token only relates to tokens within its window, with no cross-talk; b) tokens from different windows are fully related when WMSA2 takes data from different windows after WMSA1; c) an equivalent implementation of b) using spatial shuffle and alignment.

given window size $ws$ and total token number $N$
spatial shuffle: gather input data from different windows
reshape the spatial dimension into $(ws, \frac{N}{ws})$, transpose it into $(\frac{N}{ws}, ws)$, then flatten it back
spatial alignment: move tokens back to their original positions so features stay aligned with the image content
reshape the spatial dimension into $(\frac{N}{ws}, ws)$, transpose it into $(ws, \frac{N}{ws})$, then flatten it back

$$[a_{11}\dots a_{1j},\ a_{21}\dots a_{2j},\ \dots,\ a_{i1}\dots a_{ij}] \underset{alignment}{\overset{shuffle}{\rightleftharpoons}} [a_{11}a_{21}\dots a_{i1},\ a_{12}a_{22}\dots a_{i2},\ \dots,\ a_{1j}a_{2j}\dots a_{ij}]$$
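
The shuffle/alignment pair is just a reshape-transpose-flatten and its inverse; a minimal NumPy sketch (function names are my own):

```python
import numpy as np

def spatial_shuffle(tokens, ws):
    # tokens: (N, C). Reshape to (ws, N/ws), transpose, flatten back.
    N, C = tokens.shape
    return tokens.reshape(ws, N // ws, C).transpose(1, 0, 2).reshape(N, C)

def spatial_alignment(tokens, ws):
    # Inverse operation: reshape to (N/ws, ws), transpose, flatten back.
    N, C = tokens.shape
    return tokens.reshape(N // ws, ws, C).transpose(1, 0, 2).reshape(N, C)

t = np.arange(12).reshape(12, 1)
s = spatial_shuffle(t, ws=4)
print(s.flatten())  # [0 3 6 9 1 4 7 10 2 5 8 11]: tokens interleaved across windows
assert np.array_equal(spatial_alignment(s, ws=4), t)  # round-trip identity
```

Alignment undoes the shuffle exactly, so the feature map returns to image order after the shuffled attention.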

neighbor-window connection (NWC) enhancement

spatial shuffle in window-based self-attention builds cross-window connections, especially long-range ones
problem: a "grid issue" arises when processing a high-resolution image whose size is much greater than the window size
approaches to enhance the neighbor-window connection

  1. enlarge window size
  2. use shifted window
  3. introduce conv to shuffle transformer block

implemented as a depth-wise convolution with a skip connection between the WMSA and the MLP, whose kernel size equals the window size
strengthens information flow among nearby windows, thus alleviating the "grid issue"
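
A rough NumPy sketch of the NWC idea (the averaging kernel here is a placeholder for illustration; in the paper the depth-wise kernel is learned per channel, with kernel size equal to the window size):

```python
import numpy as np

def nwc(x, ws):
    # Neighbor-window connection: depth-wise conv (ws x ws kernel, 'same'
    # padding) plus a skip connection: y = x + dwconv(x).
    h, w, C = x.shape
    p = ws // 2
    xp = np.pad(x, ((p, p), (p, p), (0, 0)))
    k = np.full((ws, ws), 1.0 / (ws * ws))  # placeholder averaging kernel
    out = np.zeros_like(x)
    for i in range(h):
        for j in range(w):
            # Same spatial kernel applied to every channel (depth-wise).
            out[i, j] = np.tensordot(k, xp[i:i + ws, j:j + ws, :],
                                     axes=([0, 1], [0, 1]))
    return x + out  # skip connection

x = np.ones((8, 8, 2))
y = nwc(x, ws=3)
print(y.shape)  # (8, 8, 2)
```

Because the kernel spans a full window, each token mixes information from the adjacent windows before the MLP, which is what counters the grid artifacts.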

shuffle transformer encoder


Two successive Shuffle Transformer blocks. WMSA and Shuffle WMSA are window-based multi-head self-attention without/with spatial shuffle, respectively.

with consecutive encoder blocks alternating between WMSA and Shuffle-WMSA, the transformer encoder is computed as

$$\begin{aligned} \hat{z}_l &= \mathrm{WMSA}(\mathrm{LN}(z_{l-1})) + z_{l-1} \\ \hat{z}_l &= \mathrm{NWC}(\hat{z}_l) + \hat{z}_l \\ z_l &= \mathrm{FFN}(\mathrm{LN}(\hat{z}_l)) + \hat{z}_l \\ \hat{z}_{l+1} &= \mathrm{Shuffle\text{-}WMSA}(\mathrm{LN}(z_l)) + z_l \\ \hat{z}_{l+1} &= \mathrm{NWC}(\hat{z}_{l+1}) + \hat{z}_{l+1} \\ z_{l+1} &= \mathrm{FFN}(\mathrm{LN}(\hat{z}_{l+1})) + \hat{z}_{l+1} \end{aligned}$$

where $\mathrm{NWC}(\cdot)$ is the neighbor-window connection operation
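
The residual wiring of one block can be sketched with placeholder sub-layers (the `attn`, `nwc`, and `ffn` callables stand in for the real modules):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm over the channel dimension.
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def encoder_block(z, attn, nwc, ffn):
    # One (Shuffle-)WMSA block: pre-norm attention, NWC, and pre-norm FFN,
    # each wrapped in a residual connection.
    z = attn(layer_norm(z)) + z
    z = nwc(z) + z
    z = ffn(layer_norm(z)) + z
    return z

# With zero-output sub-layers the block reduces to the identity,
# which checks the residual wiring.
z0 = np.random.randn(16, 96)
zero = lambda t: np.zeros_like(t)
z1 = encoder_block(z0, attn=zero, nwc=zero, ffn=zero)
assert np.allclose(z1, z0)
```

A full stage alternates `attn=WMSA` and `attn=Shuffle-WMSA` (shuffle before attention, alignment after) in consecutive blocks.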

computational complexity

in ViT, given an input feature map $x\in R^{h\times w\times C}$, the FLOPs of MSA is

$$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C$$

for (Shuffle-)WMSA, attention is computed per $ws\times ws$ window, over $n_{windows}=\frac{hw}{(ws)^2}$ windows:

$$\Omega(\mathrm{(Shuffle\text{-})WMSA}) = 4hwC^2 + 2hw(ws)^2C$$
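
Plugging in numbers makes the gap concrete; a sketch assuming Shuffle-T's first stage (56x56 feature map, C=96, ws=7) at 224x224 input:

```python
def flops_msa(h, w, C):
    # 4*h*w*C^2 for the Q/K/V/output projections,
    # 2*(h*w)^2*C for the attention map and weighted sum.
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C

def flops_wmsa(h, w, C, ws):
    # Windowing makes the attention term linear in h*w for a fixed ws.
    return 4 * h * w * C ** 2 + 2 * h * w * ws ** 2 * C

print(flops_msa(56, 56, 96))      # 2003828736  (~2.0 GFLOPs)
print(flops_wmsa(56, 56, 96, 7))  # 145108992   (~0.15 GFLOPs)
```

The quadratic attention term dominates global MSA at this resolution, while the windowed version is over an order of magnitude cheaper.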

architecture variants

window size: $M=7$
query dimension of each head: $d=32$
expansion ratio of each MLP: $\alpha=4$

architecture hyper-parameters of model variants

  • Shuffle-T: C = 96, layer numbers = {2, 2, 6, 2}
  • Shuffle-S: C = 96, layer numbers = {2, 2, 18, 2}
  • Shuffle-B: C = 128, layer numbers = {2, 2, 18, 2}

where C is the channel number of the hidden layers in the first stage
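
As a hypothetical config sketch (assuming the channel width doubles at each of the four stages, as in Swin-style hierarchical designs):

```python
# Variant table from above, as a config dict (names and layout are my own).
variants = {
    "Shuffle-T": {"C": 96,  "depths": (2, 2, 6, 2)},
    "Shuffle-S": {"C": 96,  "depths": (2, 2, 18, 2)},
    "Shuffle-B": {"C": 128, "depths": (2, 2, 18, 2)},
}

def stage_channels(C, n_stages=4):
    # Channel width doubles at each stage of the hierarchy.
    return [C * 2 ** i for i in range(n_stages)]

print(stage_channels(variants["Shuffle-B"]["C"]))  # [128, 256, 512, 1024]
```

Shuffle-S deepens the third stage relative to Shuffle-T; Shuffle-B additionally widens all stages.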

Experiment

image classification

dataset ImageNet-1K, with the same augmentation and regularization as Swin
optimizer AdamW: batchsize=1024, 300 epochs, init lr=1e-3, weight decay=0.05, cosine decay, linear warm-up for 20 epochs


Comparison of different backbones on ImageNet-1K classification. Throughput is measured with a batch size of 192 on a single V100 GPU. All models are trained and evaluated at 224x224 resolution.

object detection and instance segmentation

frameworks Mask R-CNN, Cascade Mask R-CNN
dataset COCO 2017
optimizer AdamW: batchsize=16, 36 epochs, init lr=1e-4, weight decay=0.05


Object detection and instance segmentation performance on the COCO val2017 dataset using the Mask R-CNN and Cascade Mask R-CNN framework. FLOPs is evaluated on 1280x800 resolution.

semantic segmentation

framework UPerNet
dataset ADE20K, with augmentation
optimizer AdamW: batchsize=16, 1500 iterations, init lr=6e-5, weight decay=0.01, cosine decay, linear warm-up 1500 iterations


Results of semantic segmentation on the ADE20K validation set. “+” indicates that the model is pretrained on ImageNet-22K. FLOPs is measured on 1024x1024 resolution. “*” indicates the FPS reproduced by us and is measured on 512x512 resolution.

ablation study

effect of spatial shuffle and NWC


Ablation study on the effect of spatial shuffle and the neighbor-window connection on two benchmarks, FLOPs is measured on 224x224 resolution.

way of spatial shuffle
three kinds of spatial shuffle are compared

  1. long-range spatial shuffle
  2. short-range spatial shuffle: reshape the output spatial dimension into $(\frac{N}{2M}, M, 2)$
  3. random spatial shuffle: reshape output spatial dimension randomly


Ablation study on different ways of spatial shuffle on two benchmarks.

long-range spatial shuffle performs best on classification and segmentation tasks
random spatial shuffle achieves comparable performance

position of NWC module in encoder


Left: Visualization of three different positions to insert the neighbor-window connection. A: before the shuffle WMSA; B: after the residual connection of the shuffle WMSA; C: inside the MLP block. Right: Ablation study on the effect of the neighbor-window connection inserted at different positions, where A, B and C refer to three positions depicted left, and “w/o NWC” means no neighbor-window connection is inserted. FLOPs is measured on 224x224 resolution.

NWC between the Shuffle-WMSA and the MLP (position B) achieves the best performance
