此篇论文已被 AAAI 2022 收录,论文链接请见“阅读原文”。
● 简介 ●
近年来,以 DETR[1]为代表的基于 transformer 的端到端目标检测算法开始广受大家的关注。这类方法通过一组目标查询来推理物体与图像上下文的关系从而得到最终预测结果,且不需要 NMS 后处理,成为了一种目标检测的新范式。
但是,这类方法尚有一些不足之处。
首先,DETR 解码器的目标查询是一组可学习的向量。这组向量人类难以解释,没有显式的物理意义。同时,目标查询对应的预测结果的分布也没有明显的规律,这也导致模型较难优化。
为了解决上述问题,本文提出了一种基于锚点的查询设计,因此目标查询有了显式的物理意义,且每个查询仅关注对应锚点附近的区域,使得模型更容易优化。
此外,本文还提出了一种 attention 结构的变种,可以显著降低显存消耗,且对于检测任务中较难的 cross attention 依旧能保持精度不降。
如表 1 所示,最终本文算法比 DETR 精度更高,消耗显存更少,速度更快,且收敛更快(所需训练轮次更少)。
表1
● Attention 回顾 ●
首先,我们回顾一下 DETR 中 attention 的形式: ,,
这里 Q、K 和 V 分别为查询、键和值,下标 f 和 P 分别表示特征和位置编码向量,标量 为特征的维度。实际上,Q、K 和 V 还会分别经过一个全连接层,这里为了简洁省略了这部分。
DETR 的解码器包含两种 attention,一种是 self-attention,另一种是 cross-attention。
在 self-attention 中, 和 与 一样, 与 一样。其中 由上一个解码器层的输出得到,第一个解码器层的 被初始化为一个常数向量,如零向量;而 设置为一组可学的向量,为解码器中所有的 共享: ,
在 cross-attention 中, 由之前的 self-attention 的输出得到;而 和 是编码器的输出特征; 是编码器输出特征对应的位置编码向量,DETR 采用了正余弦函数来作为位置编码函数,我们将该位置编码函数记作 ,若编码器特征对应的位置记作 ,那么: 在此解释一下,H, W, C 分别是特征的高、宽和通道数目,而 是预设的目标查询数目。