DETR系列文章之–DAB DETR

本文介绍了ICLR2022的DAB DETR,它通过引入可学习的锚框改进了DETR系列模型,以适应不同尺寸的物体检测。模型结构包括空间注意力热图、详细结构、模型详解,特别是宽高调制的cross-attention模块。此外,还讨论了设置温度系数的重要性,并提供了代码讲解,涉及decoder及其layer的实现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

DETR系列文章之–DAB DETR



前言

介绍ICLR2022发表论文DAB-DETR论文基本思想即代码实现。DETR-Conditional DETR-DAB DETR
代码地址
论文地址
CSDN解读


一、模型结构

DAB-DETR认为原始DETR系列论文中:可学习的query仅仅给模型预测bbox提供了参考点(中心点)信息,没有提供box的宽和高信息。考虑引入可学习的锚框来使得模型能够自适应不同尺寸的物体。
在这里插入图片描述
object query中content query和key计算相似度完成特征提取,pos query用于限制提取区域的范围及大小

1.1 空间注意力热图可视化

可视化三个模型的空间注意力热图pk*pq,热图参考添加链接描述,DAB-DETR能够很好覆盖不同尺寸的物体。
在这里插入图片描述

1.2 详细结构

DAB-DETR直接预设N个可学习的anchor,类似SpareRCNN。然后经过宽高调制cross-attention模块,预测每个anchor box四个元素偏移量,更新anchor.
在这里插入图片描述

1.3 模型详解

首次设定N个可学习的4维anchors,然后通过PE和MLP映射成pq.
1)self-attn:常规自注意力,cq+pq
2)cross-attn:参考点(x,y)和Conditional DETR一样, Qq=cq拼接pq;
3)宽和高调制cross-attn模块,在计算pk和pq的权重相似度时引入(1/w,1/h)的尺度变换操作
在这里插入图片描述

1.4 设置温度系数

Detr中给特征图每个位置生成位置Pk完全使用的是Transformer中温度系数,而Transformer针对的是单词的嵌入向量设计的,而特征图中像素值大多分布在[0,1]之间,因此,贸然采用10000不合适,所以,本文采用了20。算是个trick吧,能涨一个点左右。
在这里插入图片描述

1.5 实验

比较四个不同的backbone: DETR-R50/R101; FPN-R50/R101;
在这里插入图片描述

二、代码讲解

2.1 decoder

整体decoder的forward函数部分:

def forward(self, tgt, memory,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 4
            ):
            
    # 第一层tgt初始化全0,output即输入的Cq!
    output = tgt 
	# 保存中间结果
    intermediate = []
    reference_points = refpoints_unsigmoid.sigmoid() # [300,batch,4]
    ref_points = [reference_points]

    # import ipdb; ipdb.set_trace()        

    for layer_id, layer in enumerate(self.layers):
    	# 取出anchor的中心Aq
        obj_center = reference_points[..., :self.query_dim]     # [num_queries, batch_size, 2]
        # 执行Pq = MLP(PE(obj_center)),将中心点转成256维度的嵌入向量
        query_sine_embed = gen_sineembed_for_position(obj_center)  
        query_pos = self.
### DAB-DETR API 文档及使用示例 #### 1. 接口概述 DAB-DETR 是一种改进版的目标检测框架,其核心在于将解码器中的 object query 建模成 anchor box 的四维坐标形式[^2]。此设计使得模型可以直接学习到物体的位置和大小信息。 #### 2. 主要模块说明 ##### 2.1 构造函数初始化参数 ```python class DAB_DETR(nn.Module): def __init__(self, backbone, transformer, num_classes, num_queries=100, aux_loss=False, position_embedding='sine', hidden_dim=256): super().__init__() self.backbone = backbone self.transformer = transformer ... ``` `num_queries`: 定义了解码器中使用的 object queries 数量,默认设置为 100。这些 queries 将被建模为具体的 anchor boxes。 ##### 2.2 数据预处理 在输入图像进入网络前,通常会经过一系列标准化操作: ```python from torchvision import transforms as T transform = T.Compose([ T.Resize((800, 800)), T.ToTensor(), ]) img_tensor = transform(image) ``` ##### 2.3 模型推理过程 完成数据准备之后,可以通过如下方式调用 `forward()` 方法来进行推断: ```python outputs = model(img_tensor.unsqueeze(0)) pred_logits = outputs['pred_logits'] # 预测类别概率分布 pred_boxes = outputs['pred_boxes'] # 预测边界框位置 (cx,cy,w,h)[^1] ``` 这里返回的结果包含了两个主要部分:一个是预测的分类得分;另一个则是预测出来的 bounding box 参数,其中心点 `(cx, cy)` 和宽高 `(w, h)` 已经过适当缩放以适应原始图片尺寸。 #### 3. 使用实例 下面是一个完整的例子来展示如何加载预训练权重并执行一次简单的对象检测任务: ```python import torch from models.dab_detr import build_model device = 'cuda' if torch.cuda.is_available() else 'cpu' # 加载预训练好的模型 model, criterion, postprocessors = build_model() checkpoint = torch.load('path_to_checkpoint.pth') model.load_state_dict(checkpoint['model']) model.to(device).eval() # 准备测试样本 image_path = "test_image.jpg" img_tensor = transform(Image.open(image_path)).to(device) with torch.no_grad(): output = model([img_tensor]) boxes = output[0]['boxes'].detach().cpu().numpy() scores = output[0]['scores'].detach().cpu().numpy() labels = output[0]['labels'].detach().cpu().numpy() ``` 上述代码片段展示了从加载预训练模型到实际运行整个流程的具体实现细节。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

猫撞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值