[2107] [NIPS 2021] Focal Self-attention for Local-Global Interactions in Vision Transformers

paper
code

Contribution

  • propose Focal self-attention (FSA), which applies fine-grained attention locally and coarse-grained attention globally

Method

model architecture


Model architecture for our Focal Transformers. As highlighted in light blue boxes, our main innovation is the proposed focal self-attention mechanism in each Transformer layer.

focal self-attention (FSA)


Left: Visualization of the attention maps of the three heads at the given query patch (blue) in the first layer of the DeiT-Tiny model. Right: An illustrative depiction of the focal self-attention mechanism. Three granularity levels are used to compose the attention region for the blue query.

FSA attends fine-grained tokens only locally, instead of attending all tokens at fine granularity
it covers as many regions as standard self-attention, but at much lower cost


The size of the receptive field (y-axis) with the increase of used tokens (x-axis) for standard and our focal self-attention. For focal self-attention, we assume increasing the window granularity by a factor of 2 gradually, but no more than 8. Note that the y-axis is logarithmic.

for a given query position, using gradually coarser granularity for its far surroundings lets FSA attain a significantly larger receptive field while attending the same number of visual tokens as the baseline
the focal mechanism thus enables long-range self-attention with much less time and memory cost
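
the trade-off in the plot above can be made concrete with a toy calculation; the fixed 8x8 region per level and the granularity-doubling-capped-at-8 schedule follow the figure description, everything else is an illustrative assumption

```python
# Toy calculation: receptive field vs. attended tokens for focal attention,
# assuming each level attends a fixed 8x8 grid of (pooled) tokens and the
# token granularity doubles per level, capped at 8 (as described in the figure).
def focal_receptive_field(num_levels, region=8, max_granularity=8):
    tokens, span, granularity = 0, 0, 1
    for _ in range(num_levels):
        tokens += region * region          # tokens attended at this level
        span = region * granularity        # side length covered by the coarsest level so far
        granularity = min(granularity * 2, max_granularity)
    return tokens, span

for levels in range(1, 5):
    n, s = focal_receptive_field(levels)
    # standard self-attention would need s*s tokens to cover the same s x s region
    print(f"levels={levels}: focal tokens={n:4d}, receptive field={s}x{s}, standard tokens={s * s}")
```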

window-wise attention


An illustration of our focal self-attention at the window level. Each of the finest square cells represents a visual token, either from the original feature map or the squeezed ones. Suppose we have an input feature map of size 20x20. We first partition it into 5x5 windows of size 4x4. Taking the 4x4 blue window in the middle as the query, we extract its surrounding tokens at multiple granularity levels as its keys and values. For the first level, we extract the 8x8 tokens which are closest to the blue window at the finest grain. Then at the second level, we expand the attention region and pool the surrounding 2x2 sub-windows, which results in 6x6 pooled tokens. At the third level, we attend an even larger region covering the whole feature map and pool 4x4 sub-windows. Finally, these three levels of tokens are concatenated to compute the keys and values for the 4x4=16 tokens (queries) in the blue window.

first define 3 terms for clarity

  1. focal levels $L$: the number of granularity levels at which tokens are extracted for focal self-attention
  2. focal window size $s_w^l$: the size of the sub-window over which tokens are summarized at level $l \in \{1, ..., L\}$
  3. focal region size $s_r^l$: the number of sub-windows attended horizontally and vertically at level $l$

focal self-attention then proceeds in 2 main steps (a minimal code sketch follows this list)

  1. sub-window pooling
    given an input feature map $x \in \mathbb{R}^{h\times w\times C}$, split it into sub-windows of size $s_w^l \times s_w^l$
    $\hat{x} = \mathrm{Reshape}(x) \in \mathbb{R}^{\frac{h}{s_w^l}\times\frac{w}{s_w^l}\times C\times(s_w^l\times s_w^l)}$
    then use a linear layer $f_p^l$ to pool each sub-window spatially
    $x^l = f_p^l(\hat{x}) \in \mathbb{R}^{\frac{h}{s_w^l}\times\frac{w}{s_w^l}\times C}$
  2. attention computation
    given the pooled feature maps $\{x^l\}_{1}^{L}$, compute the queries, keys, and values with linear projection layers
    $\begin{aligned} Q&=f_q(x^1) \\ K&=\{K^l\}_{1}^{L}=f_k(\{x^1, ..., x^L\}) \\ V&=\{V^l\}_{1}^{L}=f_v(\{x^1, ..., x^L\}) \end{aligned}$
    first extract the surrounding tokens for each query token in the feature map
    note that all tokens inside a window partition of size $s_p \times s_p$ share the same set of surroundings
    for the queries in the i-th window $Q_i \in \mathbb{R}^{s_p\times s_p\times C}$, extract the $s_r^l \times s_r^l$ keys and values from $K^l$ and $V^l$ around the window where the query lies
    then gather the keys and values from all $L$ levels to obtain
    $K_i = \{K_i^1, ..., K_i^L\} \in \mathbb{R}^{s\times C}, \quad V_i = \{V_i^1, ..., V_i^L\} \in \mathbb{R}^{s\times C}$
    where $s$ is the sum of the focal regions over all levels, i.e., $s = \sum_{l=1}^{L}(s_r^l)^2$
    note that a strict version of focal self-attention requires excluding the regions that overlap across different levels
    finally, include a relative position bias and compute the focal self-attention as
    $\mathrm{Attention}(Q_i, K_i, V_i) = \mathrm{Softmax}\left(\frac{Q_i K_i^T}{\sqrt{d}} + B\right)V_i$
    where $B = \{B^l\}_{1}^{L}$ is a learnable relative position bias, consisting of $L$ subsets, one per focal level
  • for the first level, parameterize the bias as $B^1 \in \mathbb{R}^{(2s_p-1)\times(2s_p-1)}$
    since the horizontal and vertical relative positions both lie in the range $[-s_p+1, s_p-1]$
  • for the other levels, because the pooled tokens have a different granularity from the queries, treat all queries inside a window equally
    and use $B^l \in \mathbb{R}^{s_r^l \times s_r^l}$ to represent the relative position bias between the query window and each of the $s_r^l \times s_r^l$ pooled tokens
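
a minimal single-head PyTorch sketch of the two steps above, using the illustration's setting (20x20 map, 4x4 query windows, 8x8 / 6x6 / 5x5 regions at three levels); this is a simplification and not the official implementation: relative position bias, multi-head splitting, overlap removal across levels, and boundary masking are omitted, and the `F.unfold`-based neighbourhood gathering and the name `FocalSelfAttentionSketch` are my own choices

```python
# Hedged sketch of focal self-attention: single head, no relative position
# bias, no removal of overlapped regions across levels, zero-padded borders.
# Defaults follow the illustration above (20x20 map, 4x4 windows,
# 8x8 / 6x6 / 5x5 regions at three levels); names are not the paper's API.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalSelfAttentionSketch(nn.Module):
    def __init__(self, dim, window_size=4, focal_windows=(1, 2, 4), focal_regions=(8, 6, 5)):
        super().__init__()
        self.sp = window_size                    # s_p: query window partition size
        self.sw, self.sr = focal_windows, focal_regions
        # step 1: one spatial pooling layer per level, (s_w * s_w) positions -> 1
        self.pool = nn.ModuleList([nn.Linear(s * s, 1) for s in focal_windows])
        self.f_q = nn.Linear(dim, dim)
        self.f_k = nn.Linear(dim, dim)
        self.f_v = nn.Linear(dim, dim)

    def forward(self, x):                        # x: (B, H, W, C)
        B, H, W, C = x.shape
        sp = self.sp
        nWh, nWw = H // sp, W // sp

        # queries: window partition, then linear projection -> (B, nW, sp*sp, C)
        q = x.view(B, nWh, sp, nWw, sp, C).permute(0, 1, 3, 2, 4, 5)
        q = self.f_q(q.reshape(B, nWh * nWw, sp * sp, C))

        keys, values = [], []
        for sw, sr, pool in zip(self.sw, self.sr, self.pool):
            stride = sp // sw                    # pooled tokens per query window per axis
            pad = (sr - stride) // 2             # centre the s_r x s_r region (assumes sr - stride is even)
            # step 1: sub-window pooling -> (B, H/sw, W/sw, C)
            xl = x.view(B, H // sw, sw, W // sw, sw, C).permute(0, 1, 3, 5, 2, 4)
            xl = pool(xl.reshape(B, H // sw, W // sw, C, sw * sw)).squeeze(-1)
            # step 2: project, then gather an s_r x s_r neighbourhood per query window
            k = self.f_k(xl).permute(0, 3, 1, 2)                         # (B, C, H/sw, W/sw)
            v = self.f_v(xl).permute(0, 3, 1, 2)
            k = F.unfold(k, kernel_size=sr, stride=stride, padding=pad)
            v = F.unfold(v, kernel_size=sr, stride=stride, padding=pad)
            keys.append(k.view(B, C, sr * sr, -1).permute(0, 3, 2, 1))   # (B, nW, sr*sr, C)
            values.append(v.view(B, C, sr * sr, -1).permute(0, 3, 2, 1))

        K = torch.cat(keys, dim=2)               # (B, nW, s, C) with s = sum_l (s_r^l)^2
        V = torch.cat(values, dim=2)
        attn = (q @ K.transpose(-2, -1)) / C ** 0.5
        out = attn.softmax(dim=-1) @ V           # (B, nW, sp*sp, C)
        out = out.view(B, nWh, nWw, sp, sp, C).permute(0, 1, 3, 2, 4, 5)
        return out.reshape(B, H, W, C)


# usage on the illustration's setting: 20x20 feature map, 4x4 query windows
x = torch.randn(2, 20, 20, 96)
fsa = FocalSelfAttentionSketch(dim=96)
print(fsa(x).shape)                              # torch.Size([2, 20, 20, 96])
```
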
focal transformer encoder

with encoder blocks containing FSA, each Transformer encoder layer is computed as
$\begin{aligned} \hat{z}_l &= \mathrm{FSA}(\mathrm{LN}(z_{l-1})) + z_{l-1} \\ z_l &= \mathrm{FFN}(\mathrm{LN}(\hat{z}_l)) + \hat{z}_l \end{aligned}$
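
a pre-norm block matching the two equations, reusing `FocalSelfAttentionSketch` from the sketch above; the FFN expansion ratio of 4 and GELU activation are the usual ViT choices (assumptions here), and stochastic depth is omitted

```python
import torch.nn as nn

class FocalBlockSketch(nn.Module):
    """Pre-norm encoder block: z_hat = FSA(LN(z)) + z, z = FFN(LN(z_hat)) + z_hat."""
    def __init__(self, dim, window_size=4, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.fsa = FocalSelfAttentionSketch(dim, window_size)   # from the sketch above
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, z):                         # z: (B, H, W, C)
        z = self.fsa(self.norm1(z)) + z           # z_hat = FSA(LN(z_{l-1})) + z_{l-1}
        z = self.ffn(self.norm2(z)) + z           # z_l   = FFN(LN(z_hat)) + z_hat
        return z
```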

computational complexity

in ViT, given an input feature map $x \in \mathbb{R}^{h\times w\times C}$, the FLOPs of standard MSA are
$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C$
for FSA on the same input feature map, with $\frac{h}{s_p}\times\frac{w}{s_p}$ query windows and focal level $l$:
pooling each $s_w^l \times s_w^l$-size sub-window costs
$\Omega(\mathrm{pool}) = (s_w^l)^2 C$
aggregating the sub-windows over the whole $h\times w$ feature map at each level costs
$\Omega(\mathrm{aggr}) = hwC$
the attention cost for one $s_p \times s_p$-size query window is
$\Omega(\mathrm{attn}_{win}) = (s_p)^2 C \sum_{l}(s_r^l)^2$
so the attention cost over the whole feature map is
$\Omega(\mathrm{attn}_{feat}) = hwC\sum_{l}(s_r^l)^2$
summing up, for FSA
$\Omega(\mathrm{FSA}) = L \cdot \Omega(\mathrm{aggr}) + \Omega(\mathrm{attn}_{feat}) = hwC(L + \sum_{l}(s_r^l)^2)$
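
plugging illustrative numbers into the two formulas above; the chosen h, w, C and focal regions are examples rather than the paper's settings, and note the FSA formula counts pooling and attention only (the qkv/output projection cost is not included)

```python
# Cost formulas from above, evaluated on illustrative numbers.
def msa_flops(h, w, c):
    # Omega(MSA) = 4hwC^2 + 2(hw)^2 C
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c

def fsa_flops(h, w, c, focal_regions):
    # Omega(FSA) = hwC (L + sum_l (s_r^l)^2); projections excluded, as in the text
    L = len(focal_regions)
    return h * w * c * (L + sum(r * r for r in focal_regions))

h = w = 56
c = 96
print(f"MSA: {msa_flops(h, w, c) / 1e9:.2f} GFLOPs")
print(f"FSA: {fsa_flops(h, w, c, focal_regions=(8, 6, 5)) / 1e9:.3f} GFLOPs")
```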

architecture variants


Model configurations for our focal Transformers. We introduce three configurations Focal-Tiny, Focal-Small and Focal-Base with different model capacities.

Experiment

image classification

dataset ImageNet-1K, with augmentation and regularization as in DeiT
optimizer AdamW: batch size=1024, 300 epochs, init lr=1e-3, weight decay=0.05, linear warm-up 20 epochs, cosine decay
stochastic depth 0.2, 0.2, 0.3 for Focal-T, Focal-S, Focal-B
max gradient norm clipped to 5.0
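
the recipe above in plain PyTorch, as a hedged sketch: the tiny linear model and random batch are stand-ins, and the DeiT augmentation pipeline and stochastic depth are not reproduced

```python
import math
import torch
import torch.nn as nn

# toy stand-ins so the recipe runs end-to-end; swap in a Focal Transformer and
# the ImageNet-1K loader (batch size 1024, DeiT augmentations) for real training
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1000))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

epochs, warmup_epochs = 300, 20
def lr_lambda(epoch):
    if epoch < warmup_epochs:                          # linear warm-up for 20 epochs
        return (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * t))         # cosine decay to 0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

images, targets = torch.randn(8, 3, 32, 32), torch.randint(0, 1000, (8,))
for epoch in range(2):                                 # 300 epochs in the paper
    loss = criterion(model(images), targets)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # clip grad norm to 5.0
    optimizer.step()
    scheduler.step()
```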


Comparison of image classification on ImageNet-1K for different models. Except for ViT-Base/16, all other models are trained and evaluated on 224x224 resolution.

object detection and instance segmentation

framework Mask R-CNN, Cascade Mask R-CNN
dataset COCO 2017
optimizer AdamW: 12 or 36 epochs, init lr=1e-4, weight decay=0.05
stochastic depth 0.2, 0.2, 0.3 for Focal-T, Focal-S, Focal-B


Comparisons with CNN and Transformer baselines and SoTA methods on COCO object detection. The box mAP ( A P b AP^b APb) and mask mAP ( A P m AP^m APm) are reported for RetinaNet and Mask R-CNN trained with 1x schedule.


COCO object detection and segmentation results with RetinaNet and Mask R-CNN. All models are trained with the 3x schedule and multi-scale inputs (MS). The numbers before and after “/” in columns 2 and 3 are the model size and complexity for RetinaNet and Mask R-CNN, respectively.

dataset COCO 2017
optimizer AdamW: 36 epochs, init lr=1e-4, weight decay=0.05
stochastic depth 0.2, 0.2, 0.3 for Focal-T, Focal-S, Focal-B


Comparison with ResNet-50, Swin-Tiny across different object detection methods. We use Focal-Tiny as the backbone and train all models using 3x schedule.

semantic segmentation

dataset ADE20K
optimizer AdamW: batch size=16, 160K iterations, init lr=6e-5, weight decay=0.01, polynomial decay
scaling ratios [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale evaluation
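
a hedged sketch of multi-scale evaluation with the listed scaling ratios: run the model at each scale, resize the logits back to the original resolution, and average them (the exact fusion can differ between codebases); the 1x1 conv is a stand-in for the segmentation model

```python
import torch
import torch.nn.functional as F

def multi_scale_inference(model, image, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75)):
    """Average per-pixel logits over rescaled copies of the input image."""
    _, _, H, W = image.shape
    logits_sum = 0
    for s in scales:
        scaled = F.interpolate(image, scale_factor=s, mode='bilinear', align_corners=False)
        logits = model(scaled)                                    # (B, num_classes, h', w')
        logits_sum = logits_sum + F.interpolate(logits, size=(H, W),
                                                mode='bilinear', align_corners=False)
    return (logits_sum / len(scales)).argmax(dim=1)               # (B, H, W) class map

# toy usage: a 1x1 conv stands in for the segmentation model (150 ADE20K classes)
model = torch.nn.Conv2d(3, 150, kernel_size=1)
pred = multi_scale_inference(model, torch.randn(1, 3, 512, 512))
print(pred.shape)                                                 # torch.Size([1, 512, 512])
```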


Comparison with SoTA methods for semantic segmentation on ADE20K val set. Both single- and multi-scale evaluations are reported in the last two columns. Models marked in the table are pretrained on ImageNet-22K.

ablation study

window size
one question is whether further increasing the window size helps model learning, given the enlarged receptive field


Impact of different window sizes (WSize). We change the default window size from 7 to 14 and observe consistent improvements for both methods.

necessity of window shift
window shift operations enable cross-window interactions between two successive layers


Impact of window shift (W-Shift) on Swin Transformer and Focal Transformer. Tiny models are used.

short- and long-range interactions
ablate the Focal-Tiny model into three variants:

  1. Focal-Tiny-Window: performs attention only inside each window
  2. Focal-Tiny-Local: additionally attends the fine-grained surrounding tokens
  3. Focal-Tiny-Global: additionally attends the coarse-grained squeezed tokens


Ablating Focal-Tiny model by adding local, global and both interactions, respectively. Blue bars are for image classification and orange bars indicate object detection performance. Both local and global interactions are essential to obtain good performance.

model depth
since focal attention promotes local and global interactions at each Transformer layer, one question is whether fewer layers are needed to reach a modeling capacity similar to models without global interactions
reduce the number of Transformer layers at stage 3 of Swin-Tiny and Focal-Tiny from 6 to 4 and then 2


Impact of the change of model depth. We gradually reduce the number of transformer layers at the third stage from the original 6 to 4 and further to 2. This apparently hurts the performance, but our Focal Transformer has a much slower drop rate than Swin Transformer.
