mmdetection中的TOOD动态标签分配器代码注释tood.py-assign

论文 "TOOD: Task-aligned One-stage Object Detection" 

TaskAlignedAssigner类的主要功能是为每个预测框分配一个对应的真实框或背景。这个分配过程基于计算预测框与真实框之间的对齐度量(alignment metric),并选择与每个真实框对齐度最高的前k个预测框作为候选。

分配过程:

  1. 计算所有预测框与真实框之间的对齐度量,这是通过预测的分类分数分数和预测框与gt框的IoU计算来实现的。
  2. 为每个真实框选择对齐度最高的前k个预测框作为候选。
  3. 限制正样本的中心必须在真实框内,这是因为无锚点(anchor-free)检测器只能预测正距离。
  4. 计算候选预测框的中心与真实框边界的距离,并确定哪些候选预测框的中心实际位于真实框内。
  5. 最终,为每个预测框分配一个对应的真实框索引(正样本)或0(负样本),并返回分配结果。
def assign(self,
           pred_scores,  # 预测类别得分,形状为(n, num_classes),n是预测框的数量,num_classes是类别的数量
           decode_bboxes,  # 解码后的预测边界框,形状为(n, 4)
           anchors,  # 锚点或锚框,形状为(n, 4)
           gt_bboxes,  # 真实边界框,形状为(k, 4),k是真实框的数量也就是batchsize张图片中的目标数量
           gt_bboxes_ignore=None,  # 被忽略的真实边界框,例如,在COCO数据集中标记为“crowd”的框
           gt_labels=None,  # 真实边界框的标签,形状为(k, )
           alpha=1,  # 对齐度量中类别预测得分的权重
           beta=6):  # 对齐度量中IoU得分的权重
    """Assign gt to bboxes."""
    anchors = anchors[:, :4]  # 确保锚点坐标形状
    num_gt, num_bboxes = gt_bboxes.size(0), anchors.size(0)  # 获取真实框和预测框的数量
    # 计算所有预测框和真实框之间的IoU,使用.detach()来防止计算梯度
    overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
    bbox_scores = pred_scores[:, gt_labels].detach()  # 获取每个预测框对应的真实标签的得分
    # 默认情况下,所有预测框都分配给背景(索引0)
    assigned_gt_inds = anchors.new_full((num_bboxes, ), 0, dtype=torch.long)
    assign_metrics = anchors.new_zeros((num_bboxes, ))  # 初始化分配度量,用于存储每个预测框的分配得分

    # 如果没有真实框或预测框,直接返回空的分配结果
    if num_gt == 0 or num_bboxes == 0:
        max_overlaps = anchors.new_zeros((num_bboxes, ))
        if num_gt == 0:
            assigned_gt_inds[:] = 0  # 如果没有真实框,所有预测框都标记为背景
        if gt_labels is None:
            assigned_labels = None
        else:
            assigned_labels = anchors.new_full((num_bboxes, ), -1, dtype=torch.long)
        assign_result = AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
        assign_result.assign_metrics = assign_metrics
        return assign_result

    # 计算对齐度量,即类别预测得分和IoU得分的加权乘积
    alignment_metrics = bbox_scores**alpha * overlaps**beta
    topk = min(self.topk, alignment_metrics.size(0))  # 确定每个真实框的候选预测框数量
    _, candidate_idxs = alignment_metrics.topk(topk, dim=0, largest=True)  # 选择每个真实框的前k个预测框作为候选
    candidate_metrics = alignment_metrics[candidate_idxs, torch.arange(num_gt)]  # 获取这些候选框的对齐度量
    is_pos = candidate_metrics > 0  # 确定哪些候选框是正样本(对齐度量大于0)

    # 限制正样本的中心必须位于对应的真实框内部
    anchors_cx = (anchors[:, 0] + anchors[:, 2]) / 2.0
    anchors_cy = (anchors[:, 1] + anchors[:, 3]) / 2.0
    for gt_idx in range(num_gt):
        candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
ep_anchors_cx = anchors_cx.view(1, -1).expand(num_gt, num_bboxes).contiguous().view(-1)
    ep_anchors_cy = anchors_cy.view(1, -1).expand(num_gt, num_bboxes).contiguous().view(-1)
    candidate_idxs = candidate_idxs.view(-1)

    # 计算正样本预测框中心与真实框边界的距离
    l_ = ep_anchors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
    t_ = ep_anchors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
    r_ = gt_bboxes[:, 2] - ep_anchors_cx[candidate_idxs].view(-1, num_gt)
    b_ = gt_bboxes[:, 3] - ep_anchors_cy[candidate_idxs].view(-1, num_gt)
    is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01  # 确定中心是否在真实框内
    is_pos = is_pos & is_in_gts  # 更新正样本标记,只有中心在真实框内的候选框才是正样本

    # 如果一个预测框被分配给多个真实框,选择IoU最高的那个
    overlaps_inf = torch.full_like(overlaps, -INF).t().contiguous().view(-1)
    index = candidate_idxs.view(-1)[is_pos.view(-1)]
    overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
    overlaps_inf = overlaps_inf.view(num_gt, -1).t()

    max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
    assigned_gt_inds[max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1  # 分配给对应的真实框
    assign_metrics[max_overlaps != -INF] = alignment_metrics[max_overlaps != -INF, argmax_overlaps[max_overlaps != -INF]]  # 更新分配度量

    # 如果提供了真实框标签,则为分配的预测框设置对应的标签
    if gt_labels is not None:
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        pos_inds = torch.nonzero(assigned_gt_inds > 0, as_tuple=False).squeeze()
        if pos_inds.numel() > 0:
            assigned_labels[pos_inds] = gt_labels[assigned_gt_inds[pos_inds] - 1]
    else:
        assigned_labels = None
    assign_result = AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)  # 创建分配结果实例
    assign_result.assign_metrics = assign_metrics  # 保存分配度量
    return assign_result  # 返回分配结果

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
TOOD目标检测算法是一种用于目标检测的算法,它在MS-COCO数据集上进行了实验,并取得了很好的性能。该算法通过设计一个用于评价anchor对齐指标的值来优化目标检测的结果。这个指标通过预测类别的置信度和预测边界框的IoU来计算,同时通过调整参数α和β来控制两个任务对于对齐指标的影响。TOOD算法通过动态地关注任务对齐的anchor来提高目标检测的准确性。相比于其他单阶段检测算法,TOOD具有更高的AP指标,并且参数量和计算量更少。此外,TOOD还能更好地对目标分类和定位两个任务进行对齐。\[1\]\[2\]\[3\] #### 引用[.reference_title] - *1* [二维目标检测sota---TOOD任务对齐的一阶目标检测算法](https://blog.csdn.net/qq_41621517/article/details/122130470)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [51.1 AP!TOOD:刷新单阶段目标检测新纪录!ICCV 2021 Oral](https://blog.csdn.net/amusi1994/article/details/120073068)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [目标检测之TOOD:Task-aligned One-stage Object Detection](https://blog.csdn.net/qq_41950533/article/details/124094016)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值