[Transformer] DAT: Vision Transformer with Deformable Attention

论文: https://arxiv.org/abs/2201.00520

代码: https://github.com/LeapLabTHU/DAT

2022年1月

1 简介

与CNN模型相比,基于Transformer的模型具有更大的感受野,擅长于建模长期依赖关系,在大量训练数据和模型参数的情况下取得了优异的性能。但是,计算成本较高,收敛速度较慢,过拟合的风险增加。

为了降低计算复杂度,

Swin Transformer采用基于Window的局部注意力来限制Local Window中的注意力;

Pyramid Vision Transformer(PVT)则通过对key和value特征映射进行降采样来节省计算量。

但是手工设计的注意力机制是数据不可知的。对于一个给定的 query,我们期望它的key/Value集合是灵活的,可以根据不同的输入进行调整。

Deformable Convolution Networks(DCN)的成功促使在Vision Transformer中探索一种可变形的注意力模式。由Deformable offsets引入的开销是patch数量的平方。因此,尽管最近的一些工作研究了变形机制的思想,但由于计算成本高,没有人将其作为构建强大的Backbone(如DCN)的基本构件。可变形机制要么在检测头中采用,要么作为预处理层对后续Backbone的patch进行采样。

本文提出了一种简单有效的可变形的自注意力模块,该模块以数据依赖的方式选择了自注意力中的key和value对的位置。这种灵活的方案使自注意力模块能够聚焦于相关区域并捕获更多信息。并在此模块上构造了一个强大的Pyramid Backbone,即可变形的注意力Transformer(Deformable Attention Transformer, DAT),用于图像分类和各种密集的预测任务。

不同于DCN在整个特征图上针对不同像素学习不同的offset,本文的方法学习几组与query无关的offset,将key和value移到重要区域(如图1(d)所示)。这是由于研究结果显示,全局注意力常常使得针对不同queries的注意力机制变得相同。这种设计既保留了线性复杂度,又为Transformer的主干引入了可变形的注意力模式。

具体来说:

对于每个注意力模块,首先将参考点生成为统一的网格,对于不同的输入数据 这些网格是相同的;

然后,offset网络将query特征作为输入,并为所有参考点生成相应的offset。这样一来,候选的key /value被转移到重要的区域,从而增强了原有的自注意力模块的灵活性和效率,从而捕获更多的信息特征。

Deformable Attention Transformer

2.1 Deformable Attention

在Transformer中实现DCN是一个重要的问题,在DCN中特征图上的每个元素分别学习其offset,其中H*W*C特征图上的3*3可变形卷积的空间复杂度为9HWC。如果直接在自注意力模块应用相同的机制,空间复杂度将急剧上升。

虽然Deformable DETR通过在每个检测头设置更少的key来减少这个计算开销,但是,在Backbone中,这样少的key是次要的,因为这样的信息丢失是不可接受的。

因此,本文选择 为每个query share shifted keys and values 以实现有效的权衡。

具体来说,本文提出了Deformable Attention:

从query中经过offset网络学习得到多组deformed sampling points,这些采样点决定了重要区域的位置。

采用双线性插值对特征图进行采样,然后将采样后的特征通过投影得到deformed Key/value。

最后,使用多头注意力。

另外,deformed points locations提供了一个更加有效的相对位置偏置,从而辅助deformable attention的学习。

2.2 Deformable attention module

如图2(a)所示,给定输入特征图x(H*W*C),生成一张均匀的点网格 p(HG*WG*2)。网格尺寸是从输入的特征图尺寸降采样得到的,HG=H/r, WG=W/r。参考点的值为线性间隔的2D坐标(0,0),…,(HG-1,WG-1),然后根据网格形状HG*WG将其归一化为范围[-1,+1],其中(-1,-1)代表左上角,(+1,+1)代表右下角。

为了获得每个参考点的offset,将输入特征图进行线性投影得到query token  q=xW{q},

然后输入一个轻量子网络θoffset(),生成偏移量\Delta p= \Theta offset\left (q \right )

用预定义的因子s来限制\Delta p的振幅,以防止太大的offset,

\Delta p\leftarrow stanh\left ( \Delta p \right )

然后在deformed points的位置进行特征采样,作为key和value:

2.3 offset生成

如前面所述,采用一个子网络进行Offset的生成,输入query,输出参考点的offset值。考虑到每个参考点覆盖一个局部的s×s区域(s是偏移的最大值),生成网络也应该有对局部特征的感知,以学习合理的offset。

  

因此,子网络包含2个具有非线性激活的卷积模块。

输入特征首先通过一个5×5的深度卷积来捕获局部特征。

然后,采用GELU激活和1×1卷积得到2D Offset。

2.4 Offset groups

将特征通道划分为G组。每个组的特征使用共享的子网络来生成相应的偏移量。在实际应用中,注意力模块的Head数M被设置为偏移组数G的倍数,确保多个注意力头被分配给一组deformed keys 和 values 。

2.5 Deformable 相对位置偏差

相对位置偏置对每对query和key之间的相对位置进行编码, 通过空间信息增强了注意力。

考虑到一个形状为H*W的特征图, 其相对坐标位移分别位于二维空间的[-H,H]和[-W,W]的范围内。

在Swin Transformer中, 构造了相对位置偏置表

通过在两个方向上的相对位移来索引表格,得到相对位置偏置B。

由于可变形注意力具有连续的key位置, 计算在归一化范围[-1,+1]内的相对位移 , 然后在相对偏置表中插值 , 以覆盖所有可能的偏移值。

2.6 计算复杂度

可变形多头注意力(DMHA)的计算成本与PVT或Swin Transformer中对应的计算成本相似。唯一的额外开销来自于用于生成偏移量的子网络。整个模块的复杂性可以概括为:

Ns为采样点数量。

offset网络的计算成本与通道数成线性关系,计算量相对较小。

此外,通过选择一个较大的下采样因子r , 复杂性将进一步降低,

有利于具有更高分辨率输入的任务, 如目标检测和实例分割。

2.7 模型架构

网络架构与PVT等具有相似的金字塔结构,广泛适用于需要多尺度特征图的各种视觉任务。

首先对形状为H×W×3的输入图像进行4×4不重叠的卷积嵌入,然后送入归一化层,得到H/4×W/4×C 的patch embeddings。为了构建一个层次特征金字塔,Backbone包括4个stage。

在2个连续的stage之间,有一个stride=2的2×2卷积进行下采样,使空间尺寸减半,并使特征尺寸翻倍。

在分类任务中,首先对最后一个stage输出的特征图进行归一化处理,然后采用具有合并特征的线性分类器来预测。

在目标检测、实例分割和语义分割任务中,DAT扮演着Backbone的作用,以提取多尺度特征。

为每个stage的特征添加一个归一化层,再将它们输入接下来的模块。

在DAT的第三和第四stage引入了连续的Local Attention和Deformable Attention Block。特征图首先通过基于Window的Local Attention进行处理,以局部聚合信息,然后通过Deformable Attention Block对局部增强token之间的全局关系进行建模。

由于前两个stage主要是学习局部特征,key和value具有较大的空间大小,大大增加了Deformable Attention的点积和双线性插值的计算开销。因此,为了实现模型容量和计算负担之间的权衡,这里只在第三和第四stage放置Deformable Attention,在前两stage使用Swin Transformer中的Shift Window Attention。

建立了不同参数和FLOPs的3个变体。

3 实验 

3.1 Classification

3.2 Detection 

将DAT作为RetinaNet、Mask R-CNN和Cascade Mask R-CNN的backbone:

 

3.3 Segmentation

 将DAT作为SemanticFPN和UperNet的backbone:

3.4 消融实验

首先评估了提出的可变形偏移量和可变形相对位置嵌入的有效性,如表6所示。

无论是在特征采样中采用偏移量,还是使用可变形的相对位置嵌入,都提供了+0.3的提升。

作者还尝试了其他类型的位置嵌入,包括固定的位置嵌入和深度卷积。

可变形相对位置嵌入更符合Deformable attention。

从表6中的第6行和第7行也可以看出,模型可以在前两个stage使用不同的注意力模块。

P表示使用SRA【36 Pyramid vision transformer】,S表示使用shifted-window attention。

用不同阶段的Deformable attention取代了Swin Transformer shift window attention。

替换最后两个阶段性能提高0.7。在早期阶段用更多Deformable attention代替会略微降低精度。

3.5 Visualization

可以看到采样点被移动到目标上了。

在左边一列中,变形的点被收缩成两个目标长颈鹿,而其他的点则是保持一个几乎均匀的网格和较小的偏移量。

在中间的一列中,变形点密集地分布在人的身体和冲浪板中。

右边的一列显示了变形点聚集于六个甜甜圈。

上述可视化表明,DAT可以学习到有意义的偏移量,以采样更好的注意力key,以提高各种视觉任务的表现。

  • 3
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值