Anchor DETR 论文笔记

Anchor DETR: Query Design for Transformer-Based Object Detection

论文连接:https://arxiv.org/abs/2109.07107v2
源码连接:https://github.com/megvii-research/AnchorDETR

在以前的基于Transformer的检测器中,对象查询是一组可学习的embeddings

然而,每一个学习到的嵌入都没有明确的物理意义,我们不能解释它将集中在哪里
由于每个对象查询的预测插槽都没有特定的模式,因此很难进行优化
换句话说,每个对象查询都不会关注特定的区域

为了解决这些问题,在Anchor DETR中,对象查询是基于锚点的,这在基于cnn的检测器中被广泛使用
因此,每个对象查询都集中在锚点附近的对象上

Anchor DETR的查询设计可以在一个位置上预测多个对象来解决困难:“一个区域,多个对象”

同时设计了一种Attention变体,它可以降低内存使用,同时获得与DETR中的标准的Attention相似或更好的性能

Anchor DETR训练50个epochs即可,相比于DETR也有10x左右的速度提升(论文中并未给出准确的速度提升,论文中提到的是10x fewer epochs)


下图的a是在DETR论文中就给出过的一张可视化图
每一个点表示预测框体的中心点相对于图像的相对坐标(图像的尺寸被归一化成1x1)
绿色点表示小目标,红色表示水平形状(长方形)的框体,蓝色表示垂直类型的
每个slot都有偏好的图像区域以及目标的大小和形状

DETR中每个对象查询的预测都与不同的区域相关,每个对象查询都将负责一个非常大的区域。这种位置歧义,即对象查询不关注特定的区域,使得其难以优化

下图b是Anchor DETR的slot可视化
每个对象查询的三个模式的所有预测都分布在相应的锚点周围
每一行表示了一种pattern
最下面一行是参考点
换句话说,它演示了每个对象查询只关注相应锚点附近的对象。因此,对象查询可以很容易地解释

由于对象查询具有特定的模式,并且不需要预测远离相应位置的对象,因此网络可以更容易地进行优化

请添加图片描述
DETR学习到的对象查询很难被解释。它没有明确的物理意义,每个对象查询的相应预测槽也没有特定的模式

DETR中每个对象查询的预测都与不同的区域相关,每个对象查询都将负责一个非常大的区域。这种位置歧义,即对象查询不关注特定的区域,使得其难以优化

Anchor DETR
对象查询是对锚点坐标的编码,因此每个对象查询都具有显式的物理意义

然而,这个解决方案将遇到困难:多个对象在同一个位置附近。在这种情况下,此位置的一个对象查询不能预测多个对象,因此来自其他位置的对象查询必须协同预测这些对象
它将使每个对象查询负责一个更大的区域
因此,我们通过向每个锚点添加多个模式来改进对象查询设计,以便每个锚点都可以预测多个对象

由于对象查询具有特定的模式,并且不需要预测远离相应位置的对象,因此可以更容易地进行优化


Method

Anchor Points

在基于cnn的检测器中,锚点始终是特征图的对应位置。但它在基于Transformer的探测器中可以更灵活

锚点可以是learned points、均匀网格点或其他手工设计锚点

Anchor DETR尝试了两种点
一个是网格锚点,另一个是learned points

请添加图片描述

  1. 网格锚点被固定为图像中的均匀网格点
  2. learned points以0到1的均匀分布随机初始化,并更新为学习参数

Attention Formulation

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T d k ) V Q = Q f + Q p , K = K f + K p , V = V f , \begin{gathered} \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V \\ Q=Q_f+Q_p, K=K_f+K_p, V=V_f, \end{gathered} Attention(Q,K,V)=softmax(dk QKT)VQ=Qf+Qp,K=Kf+Kp,V=Vf,

Q K V在相互计算之前,都会先经过一次Linear

DETR的decoder有两种attention,一个是self-attention,另一个是cross-attention
decoder的position embedding Qp 是一个learned embedding

Q p =  Embedding  ( N q , C ) Q_p=\text { Embedding }\left(N_q, C\right) Qp= Embedding (Nq,C)

Anchor Points to Object Query

通常,解码器中的Qp被视为对象查询,因为它负责区分不同的对象

本文提出了基于锚点 Pos ⁡ q \operatorname{Pos}_q Posq 的对象查询设计方法。
Pos ⁡ q ∈ R N A × 2 \operatorname{Pos}_q \in \mathbb{R}^{N_A \times 2} PosqRNA×2 表示Na个点,其位置为(x,y),范围为0到1
然后,基于锚点的对象查询可以表示为:

Q p = Encode ⁡ ( Pos ⁡ q ) Q_p=\operatorname{Encode}\left(\operatorname{Pos}_q\right) Qp=Encode(Posq)

这意味着我们将锚定点编码为对象查询

那么如何设计编码函数呢?由于对象查询被设计为查询位置嵌入,因此最自然的方法是使用key的位置编码功能

Q p = g ( Pos ⁡ q ) , K p = g ( Pos ⁡ k ) Q_p=g\left(\operatorname{Pos}_q\right), K_p=g\left(\operatorname{Pos}_k\right) Qp=g(Posq),Kp=g(Posk)
这里的g是位置编码函数
在本文中,我们不仅仅使用启发式 g s i n g_{s i n} gsin ,并且使用具有两个线性层的小MLP网络来适应它

Multiple Predictions for Each Anchor Point

为了处理一个位置可能有多个对象的情况,进一步改进了对象查询,以便为每个锚点预测多个对象,而不是仅仅是一个预测

DETR的object query 为 Q f i n i t ∈ Q_f^{i n i t} \in Qfinit R N q × C \mathbb{R}^{N_q \times C} RNq×C
一共有 N q N_q Nq 个 object queries,每一个为 Q f i ∈ Q_f^i \in Qfi R 1 × C \mathbb{R}^{1 \times C} R1×C

这里在多一个pattern embedding, Anchor DETR的默认设置 Np=3,如下:

Q f i = Embedding ⁡ ( N p , C ) Q_f^i=\operatorname{Embedding}\left(N_p, C\right) Qfi=Embedding(Np,C)

Row-Column Decoupled Attention

Transformer将花费大量的GPU内存,这可能会限制其对高分辨率特性或其他扩展的使用
Deformable DETR可以降低内存成本,但它将导致内存的随机访问,这可能对现代的大规模并行性加速器不友好

还有一些注意模块具有线性复杂度,不会导致内存的随机访问。然而,在实验中,发现这些注意模块不能很好地处理类似detr的探测器
这可能是因为解码器中的交叉注意比自我注意要困难得多

这里提出了行列分解Attention,不仅可以降低内存使用,还可以达到比DETR相似或者更好的性能

RCDA的核心是分解2D的key feature ( K f ∈ R H × W × C K_f \in \mathbb{R}^{H \times W \times C} KfRH×W×C)为 1D的 行特征 ( K f , x ∈ R W × C K_{f, x} \in \mathbb{R}^{W \times C} Kf,xRW×C)以及1D的列特征 ( K f , y ∈ R H × C K_{f, y} \in \mathbb{R}^{H \times C} Kf,yRH×C),然后在行和列上分别执行自注意力

默认情况下,我们选择一维全局平均池来解耦关键特征
在不失去一般性的前提下,我们假设W≥H

那么RCDA可以表示为:

请添加图片描述
weighted_sumW 和 weighted_sumH 运算分别沿宽度维度和高度维度进行加权和

现在我们来分析一下为什么它可以保存内存。
在前面的公式中,我们假设多头注意力的头部数M为1,而不失一般性,但我们应该考虑头部数M进行内存分析

在DETR中主要的内存消耗是 A ∈ R N q × H × W × M A \in \mathbb{R}^{N_q \times H \times W \times M} ARNq×H×W×M

在RCDA中的weight map是 A x ∈ R N q × W × M A_x \in \mathbb{R}^{N_q \times W \times M} AxRNq×W×M A y ∈ R N q × H × M A_y \in \mathbb{R}^{N_q \times H \times M} AyRNq×H×M
这两者一定会比A小
并且,RCDA中的主要内存成本是临时结果Z

因此,主要比较A和Z

r = ( N q × H × W × M ) / ( N q × H × C ) = W × M / C \begin{aligned} r & =\left(N_q \times H \times W \times M\right) /\left(N_q \times H \times C\right) \\ & =W \times M / C \end{aligned} r=(Nq×H×W×M)/(Nq×H×C)=W×M/C

一般情况下 M=8,C=256
因此,当长边W等于32时,内存成本大致相同,这是目标检测中C5特征的典型值

而当使用高分辨率功能时,它可以节省内存,例如为C4或DC5功能节省大约2倍内存,为C3功能节省4倍内存

模型结构

替换encoder中的self-attention,以及decoder中的cross-attention 为RCDA attention

请添加图片描述


Experiment

请添加图片描述

Ablation Study

Effectiveness of each component

请添加图片描述

Multiple Predictions for Each Anchor Point

请添加图片描述

Anchor Points Types

请添加图片描述

Row-Column Decoupled Attention

请添加图片描述
与标准注意模块相比,具有线性复杂度的注意模块可以显著降低训练记忆。然而,他们的性能显著下降
这可能是因为在类detr探测器中的交叉注意比自我注意要困难得多
相反,行列解耦注意力也可以达到类似的性能。如前所述,行-列解耦注意可以显著减少内存,并在使用C5特性时获得大致相同的内存成本

例如,RCDA DC5将训练记忆从10.5G减少到4.4G,并在使用DC5功能时获得了与标准注意力相同的性能。

Prediction Slots of Object Query

请添加图片描述

  1. 大的box通常出现在模式a
  2. 小的box通常出现在模式b
  3. 模式c比较均衡
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值