[论文评析] ArXiv,2021, Focal Self Attention技术分析

动机

通过整合fine-grained local attention和coarse-grained golbal attention,来克服各自的问题。

Focal self Attention

在这里插入图片描述
所谓的Focal self attention简单来说就是对距离Query越近的区域进行细粒度fine granulity的attention, 对距离Query越远的区域进行粗粒度的attention,通过调整粒度的级别最终会得到一个层次化的feature maps, 然后把这些feature map经过Flatten、Concatenation等操作转化为Vector, 然后对Vector分别进行Linear projection得到 Key和Value, 然后进行常规的Scaled Dot-Product Attention (SDPA)即可。

基本概念

为了控制粒度,作者引入了三个基本定义,分别为:
1.Focal level L L L: 用于进行Focal self attention的粒度级别的数量;
2.Focal Window size s w l s_{w}^{l} swl: 第 l l l个粒度级别上用于提取tokens的sub-window的尺寸;
3.Focal Region size s r l s_{r}^{l} srl: 第 l l l个粒度级别上水平/垂直方向上包含的sub-window的数量;

流程

1.网格化
假设Feature map的原始尺寸为 [ H , W , d ] [H, W, d] [H,W,d], 常规做法是对每个token都要进行一次attention, time和memory开销太大, 因此通过网格化将feature map分成一个个的sub windows, 然后逐window进行attention, 即所谓的window-wise self attention。
假设sub window的尺寸为 [ s p , s p ] [s_{p}, s_{p}] [sp,sp], 则总的windows数量为 H W s p s p \frac{HW}{s_{p}s_{p}} spspHW, 记网格化后的尺寸为 [ s p , s p , H / s p ⋅ W / s p ⋅ d ] [sp, sp, H/s_{p} \cdot W/s_{p} \cdot d] [sp,sp,H/spW/spd].
注:通过设置不同的大小的 s p s_{p} sp(实际上就是 s w l s_{w}^{l} swl), 可以得到不同粒度级别的surroundings.
2. 子窗口池化
假设输入: X ∈ R H ⋅ M ⋅ d X \in R^{H \cdot M \cdot d} XRHMd,
对于第 l l l个粒度级别, 其网格化操作形式化定义为:
X ^ = R e s h a p e ( X ) ∈ R s w l x s w l ( H / s w l ⋅ W / s w l ⋅ d ) \hat{X}=Reshape(X) \in R^{s_{w}^{l} {\rm x} s_{w}^{l} {\rm} (H/s_{w}^{l} \cdot W/s_{w}^{l} \cdot d)} X^=Reshape(X)Rswlxswl(H/swlW/swld),
然后池化操作形式化为:
x l = f p l ∈ R H / s w l ⋅ W / s w l ⋅ d x^{l}=f_{p}^{l} \in R^{H/s_{w}^{l} \cdot W/s_{w}^{l} \cdot d} xl=fplRH/swlW/swld.
然后遍历 l , l ∈ 1 , 2 , . . . , h l, l \in {1,2,...,h} l,l1,2,...,h, 最终得到输入feature map的层次化surroundings: { x l } l = 1 h \{x^{l}\}_{l=1}^{h} {xl}l=1h.
3.聚合操作
x l x^{l} xl为feature map,首先通过Flatten操作将其转化为Vector, 然后再把 h h h个这样的Vector进行Concat操作, 最终得到的tokens向量为 x t o t a l ∈ R s x d x_{total} \in R^{s {\rm x} d} xtotalRsxd, 其中 s = ∑ l = 1 h ( s w l ) 2 s=\sum_{l=1}^{h} (s_{w}^{l})^{2} s=l=1h(swl)2.

4.计算Query, Key和Value
通过分别进行Linear projection即可得到Query, Key和Value,形式化定义如下:
Q = f q ( x 1 ) Q=f_{q}(x^{1}) Q=fq(x1)
K = f k ( x t o t a l ) K=f_{k}(x_{total}) K=fk(xtotal)
V = f v ( x t o t a l ) V=f_{v}(x_{total}) V=fv(xtotal)

5.Attention计算
对于位于第 i i i个sub-window Q i ∈ R s p ⋅ s p ⋅ d Q_{i} \in R^{s_{p} {\cdot s_{p} \cdot d}} QiRspspd的Query, 其对应的Key和Value分别记为 K i ∈ R s ⋅ d K_{i} \in R^{s \cdot d} KiRsd, V i ∈ R s ⋅ d V_{i} \in R^{s \cdot d} ViRsd.
计算公式如下:
A t t e n t i o n ( Q i , K i , V i ) = s o f t m a x ( Q i K i T d + B ) V i Attention(Q_{i}, K_{i}, V_{i})=softmax(\frac{Q_{i} K_{i}^{T}}{\sqrt{d}+B}) V_{i} Attention(Qi,Ki,Vi)=softmax(d +BQiKiT)Vi
最终输出的尺寸依然为 R s p ⋅ s p ⋅ d R^{s_{p} {\cdot s_{p} \cdot d}} Rspspd
然后遍历 i , i ∈ { 1 , 2 , . . , H / s p ⋅ W / s p } i, i \in\{1,2,.., H/s_{p} \cdot W/s_{p}\} i,i{1,2,..,H/spW/sp}重复上述操作即可。

时间复杂度分析

网格化实际上就是对数据进行Reshape, 假设原始输入 X ∈ R M x N x d X \in R^{M {\rm x} N {\rm x} d} XRMxNxd
,Reshape之后第 l l l X ^ ∈ R s w l x s w l x ( M s w l ⋅ N s w l ⋅ d ) \hat{X} \in R^{s_{w}^{l} {\rm x} s_{w}^{l} {\rm x} (\frac{M}{s_{w}^{l}} \cdot \frac{N}{s_{w}^{l}} \cdot d)} X^Rswlxswlx(swlMswlNd)
对其池化的时间复杂度为 O ( s w l ⋅ s w l M s w l ⋅ N s w l ⋅ d ) = O ( M ⋅ N ⋅ d ) O(s_{w}^{l} \cdot s_{w}^{l} \frac{M}{s_{w}^{l}} \cdot \frac{N}{s_{w}^{l}} \cdot d)=O(M\cdot N \cdot d) O(swlswlswlMswlNd)=O(MNd), 对所有L个层池化的时间复杂度为 O ( O ( L ⋅ M ⋅ N ⋅ d ) ) O(O(L \cdot M\cdot N \cdot d)) O(O(LMNd)).

每个Query所属sub-window Q i ∈ R s p x s p x d Q_{i} \in R^{s_{p} {\rm x} s_{p} {\rm x} d} QiRspxspxd, 对应 K i , V i ∈ R s x d K_{i}, V_{i} \in R^{s {\rm x} d} Ki,ViRsxd, 因此计算Attention的时间复杂度为 O ( ( s p ) 2 ⋅ d ⋅ s ) O((s_{p})^{2}\cdot d \cdot s) O((sp)2ds), 共有 M s p ⋅ N s p \frac{M}{s_{p}} \cdot \frac{N}{s_{p}} spMspN个这样的sub-window, 因此进行Attention总的时间复杂度为 M ⋅ N ⋅ d ⋅ ( ∑ 1 L ( s r l ) 2 ) M\cdot N \cdot d \cdot (\sum_{1}^{L} (s_{r}^{l})^{2}) MNd(1L(srl)2)

因此,总的时间复杂度为: O ( M ⋅ N ⋅ d ⋅ ( L + ∑ l = 1 L ( s r l ) 2 ) ) O(M\cdot N \cdot d \cdot (L + \sum_{l=1}^{L} (s_{r}^{l})^{2})) O(MNd(L+l=1L(srl)2)).

点评

1.fine-grained local attention不能抓住global的信息,coarse-grained golbal attention虽然能抓住global information但因为粒度粗,因此,这两种方式实际上都不能发挥出NLP中原始Transformer中attention的建模能力。
focal self attention 实际上正是对这两种attention的整合,因为它既能抓住closet surroundings的fine-grained local信息,同时又能抓住far surroundings的coarse-grained global信息。具体来说, Key和Value的来源不是某个单一粒度级别的tokens, 而是多个粒度级别的层次化的tokens的融合。
2.所谓的focal self attention最核心的东西可以看成是常规scaled dot-product attention的一个前置操作。

与Transformer中MHA的比较

1.Transformer中的MHA处理的是序列,对于长度为n,维度为d的序列 X ∈ R n x d m o d e l X \in R^{n {\rm x} d_{model}} XRnxdmodel, Q ∈ R n x d k Q \in R^{n {\rm x} d_{k}} QRnxdk, 对应 K ∈ R m x d k , V ∈ R m x d v K \in R^{m {\rm x} d_{k}}, V \in R^{m {\rm x} d_{v}} KRmxdk,VRmxdv, d k , d v d_{k}, d_{v} dk,dv完全取决于Linear projection输出的维度, m m m为K-V对的数量;而在focal self attention这里, Q i ∈ R s p ⋅ s p ⋅ d Q_{i} \in R^{s_{p} \cdot {s_{p} \cdot d}} QiRspspd, 其中 d d d为Figure的channel, s p s_{p} sp取决于网格化时设置的每个sub-window的宽度,对应 K i , V i ∈ R s ⋅ d K_{i}, V_{i} \in R^{s {\cdot} d} Ki,ViRsd, 其中s为所有L层池化结果经Flatten后再Concat的长度。
2.原始Transformer中为MHA, 通过多次Linear projection得到多个子空间,分别进行SDPA再聚合,论文中提到这既可以节省memory,提高efficancy,同时提升泛化能力; focal self attention中用的是网格化方法,Attention是在Window级别,而不是每个Query position, 这减少了memory, 节省了time.

几个疑点

1.论文中Fig.4 与正文中描述略不同, 图中是不同level的tokens先concatenation后进行linear projection产生 K i , V i K_{i}, V_{i} Ki,Vi, 文本描述部分是对每个level的分别进行linear projection (即 f k , f v f_{k}, f_{v} fk,fv), 然后再concatenation.

2.Fig. 4中Query position所在区域恰好位于Feature map的中心,因此每个level的tokens都是方形, 然而当Query position位于边缘时,怎么办? 需要做填充吗? 从论文中没有看到类似信息。

3. s p s_{p} sp, s w l s_{w}^{l} swl s r l s_{r}^{l} srl之间应该满足的关系:
这一点论文中并没有明确说明。
很明显, { s w l } l = 1 L \{s_{w}^{l}\}_{l=1}^{L} {swl}l=1L应该都能被 s p s_{p} sp整除, 特别的,当 s w l = 1 s_{w}^{l}=1 swl=1时,粒度级别最细; 当 s w l = s p s_{w}^{l}=s_{p} swl=sp时, 粒度级别最粗。

三者之间的关系, Fig.4中Query所在的蓝色区域周围上下左右包含的尺寸为 s w l x s w l s_{w}^{l} {\rm x} {s_{w}^{l}} swlxswl的方格的数量都为2, 因此可以很容易得出结论: s r l = 4 + s p s w l s_{r}^{l}=4 + \frac{s_{p}}{s_{w}^{l}} srl=4+swlsp.

Focal Transformer

在这里插入图片描述
如图, Focal Transformer包含多个Stage: { S t a g e i } i = 1 4 \{\rm Stage i\}_{i=1}^{4} {Stagei}i=14, 每个 S t a g e i {\rm Stage i} Stagei里面通过Stack相同的building block组成,数量为 N i N_{i} Ni,每个building block主要由Focal Self-Attention和Multi-Layer Perceptron 组成, 最核心的就是上面提到的Focal self-attention.

可以看到,在每个Stage之前都有一个Patch Embedding (PE), PE实际上就是一个Convolution层, 其中第一层的PE作用是将 X ∈ R M x N x 3 X \in R^{M {\rm x} N {\rm x} 3} XRMxNx3映射到 R M x N x d R^{M {\rm x} N {\rm x} d} RMxNxd, 然接下来,每个Stage开始之前都会有一个PE将上一个Stage输出的Feature map的Spatial dimension减小为原来的1/2, Channel dimension增大为原来的2倍。

总结

Reference

1.ArXiv, 2021, Focal Self-attention for Local-Global Interactions in Vision Transformers.

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MasterQKK 被注册

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值