动态标签分配 - 以 Nanodet-plus 中的代码为例

标签分配

部分内容参考自:https://www.bilibili.com/video/BV1ge41117va

简单介绍一些特点,主要结合动态标签分配的一个实例来看

从更高抽象的层面理解 assign

所有用于最终检测的特征图上的所有 point 都具备学习并预测目标的能力,在给定一幅图像及其目标 gt bbox 的情况下,为每个目标 gt bbox 选择恰当的特征图 point 进行学习预测的过程就是分配。

个人理解就是:

每个 anchor 锚点都有预测 bbox 的能力,对一张图像来说,将先验框 gt_bbox 与合适的锚点 points 进行匹配,训练 points 来预测。

一个锚点 point 分配给一个 gt_bbox (即标注框),但是一个 gt_bbox 可以和多个锚点 points 进行匹配

前提:

仅当感受野中心命中 gt bbox 的 point 才有可能被选用来预测这个 bbox

个人理解是:anchor 锚点要位于 gt_bbox 中才能被用来预测

两种类型的匹配机制:

基于规则的分配、自动分配(与网络的输出有关)

目标匹配是 One-stage Anchor-free 检测器核心中的核心!!!

实例分析:

NanoDet-Plus 使用的 DynamicSoftLabelAssigner 来分析,这里主要分析动态分配:

class DynamicSoftLabelAssigner(BaseAssigner):
    """Computes matching between predictions and ground truth with
    dynamic soft label assignment.
        使用动态软标签分配计算预测与真实值之间的匹配
    Args:
        topk (int): Select top-k predictions to calculate dynamic k
            best matchs for each gt. Default 13.
            为每个 gt 选择 k 个最佳预测来计算动态 k 最佳匹配。默认值为13
        iou_factor (float): The scale factor of iou cost. Default 3.0.
            IoU 代价的缩放因子。默认值为3.0
        ignore_iof_thr (int): whether ignore max overlaps or not.
            Default -1 (1 or -1).
            是否忽略最大重叠
    """

    def __init__(self, topk=13, iou_factor=3.0, ignore_iof_thr=-1):
        self.topk = topk
        self.iou_factor = iou_factor
        self.ignore_iof_thr = ignore_iof_thr

以 gt 开头的变量为真实标注信息

num_priors 即锚点 point 的数量,num_gts 即一幅图中真实标注的数量,其中 decoded_bboxes 为根据 preds 信息预测的 bboxes

将 point 和 gt 进行匹配,还是刚才所说的,一个 point 对应一个 gt,一个 gt 可以对应多个 point

简单总结一下过程(具体详细的内容看代码及注释):

以下用()包住的内容为张量尺寸大小,对于理解也十分有帮助


重点部分

首先,初步选出可能匹配的锚点,在所有锚点(num_priors)中选出 gt_bboxes 包住的锚点,即初步的有效锚点 (num_valid)

然后,计算代价矩阵 cost_matrix,以及 IoU 矩阵 pairwise_ious,均为**(num_valid, num_gts)**大小,即初步有效的锚点与真实标注的交叉矩阵

调用 dynamic_k_matching,根据 iou 排序,选出一个 gt_bbox 对应的 topk 个锚点,计算 iou 的和,将其作为 dynamic_k,将其作为该 gt_bbox 匹配的锚点个数(规定下限为1个,cost 最小的前 dynamic_k 个锚点),对每个 gt_bbox 均为同样的操作,如果存在一个锚点与多个 gt_bbox 匹配,则只保留代价最小的那一个 gt_bbox,并更新有效锚点为匹配了 gt_bbox 的锚点

最终得到锚点与 gt_bbox 的匹配,一个或多个有效锚点 priors 匹配一个 gt_bbox


更多的细节查看下方提供的代码即注释:

   def assign(
        self,
        pred_scores,		# [num_priors, num_classes]
        priors,				# [num_priors, 4]	 	[cx, cy, stride_x, stride_y]
        decoded_bboxes,		# [num_priors, 4]	 	[tl_x, tl_y, br_x, br_y]
        gt_bboxes,			# [num_gts, 4]		 	[tl_x, tl_y, br_x, br_y]
        gt_labels,			# [num_gts]
        gt_bboxes_ignore=None,
    ):
        INF = 100000000
        num_gt = gt_bboxes.size(0)
        num_bboxes = decoded_bboxes.size(0)

        # assign 0 by default
        # 创建一个与 decoded_bboxes 在同一设备上
        # 长度为 num_bboxes, 类型为 torch.long 的一维向量
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes,), 0, dtype=torch.long)

        # 锚点中心 (N, 2)
        prior_center = priors[:, :2]
        # (N, M, 2) <= (N, 1, 2) - (M, 2)   广播规则
        lt_ = prior_center[:, None] - gt_bboxes[:, :2]
        # 同上 (N, M, 2)
        rb_ = gt_bboxes[:, 2:] - prior_center[:, None]

        # 合并左上角和右下角的相对位置信息 (N, M, 4)
        deltas = torch.cat([lt_, rb_], dim=-1)
        # 判断 N个 锚点是否在 M个 gt_bboxes 内部, 得到 (N, M) 尺寸的向量
        is_in_gts = deltas.min(dim=-1).values > 0
        # (N, M) => (N, ), 对每个锚点, 判断是否有一个 gt_bboxes 包含它
        # valid_mask 表示了用于预测的锚点是否在 gt_bboxes 内部	(num_priors)
        valid_mask = is_in_gts.sum(dim=1) > 0

        # 获取有效锚点 (在gt_bboxes内部) 的对应的 preds 	(label以及bbox)
        # (num_valid, 4)
        valid_decoded_bbox = decoded_bboxes[valid_mask]
        # (num_valid, num_classes)
        valid_pred_scores = pred_scores[valid_mask]			
        # 被 gt_bboxes 包含的锚点数量 num_valid
        num_valid = valid_decoded_bbox.size(0)

        # 如果没有 gt_bboxes, 没有预测框或者没有有效匹配, 则直接返回空的分配结果
        if num_gt == 0 or num_bboxes == 0 or num_valid == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes,))	# 0
            if num_gt == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = decoded_bboxes.new_full(
                    (num_bboxes,), -1, dtype=torch.long
                )
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
            )

        # 计算有效匹配的锚点预测的 bbox 与 gt_bboxes 之间的 IoU 矩阵
        # (num_valid, num_gts)  <- 	(num_valid, 4), (num_gts, 4)
        pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)
        # 计算 IoU 的代价
        iou_cost = -torch.log(pairwise_ious + 1e-7)

        # 将真实的类别转换成 onehot 编码 (num_valid, num_gts, num_classes)
        gt_onehot_label = (
            F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1])
            .float()
            .unsqueeze(0)				# 在第一个维度上增加一个维度
            .repeat(num_valid, 1, 1)	# 第一个维度重复 num_valid 次
        )
        # 赋值有效类别的分数  (num_valid, num_classes) ->
        # (num_valid, 1, num_classes) -> (num_valid, num_gts, num_classes)
        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)

        # 生成软标签, 考虑 IoU 权重 (num_valid, num_gts, num_classes)
        soft_label = gt_onehot_label * pairwise_ious[..., None]
        # 软标签(真实标签 * IoU) - 预测得分
        scale_factor = soft_label - valid_pred_scores.sigmoid()

        # 使用二元交叉熵损失计算分类损失 (num_valid, num_gts, num_classes)
        cls_cost = F.binary_cross_entropy_with_logits(
            valid_pred_scores, soft_label, reduction="none"
        ) * scale_factor.abs().pow(2.0)
		
        # (num_valid, num_gts)
        cls_cost = cls_cost.sum(dim=-1)

        # 计算总代价, cls 代价 + bbox 代价 (num_valid, num_gts)
        cost_matrix = cls_cost + iou_cost * self.iou_factor
		# 时刻记着: valid 指的是在 gt_bboxes 内的锚点 point 的索引
        
        # 根据代价矩阵, iou矩阵, 均为 (num_valid, num_gts)
        # 进行动态 K-matching, 得到匹配的部分锚点, 这些锚点每个都对应一个 gt_bbox
        # 每个锚点分配给一个 gt_bbox, 一个 gt_bbox 可以对应多个锚点
        matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
            cost_matrix, pairwise_ious, num_gt, valid_mask
        )
        
        # convert to AssignResult format		
        # matched_pred_ious 为锚点预测的 bbox 与匹配的 gt_bbox 的 iou
        # matched_gt_inds 为锚点匹配的 gt_bbox 的索引
		# 分配的 gt_bbox 的索引, 未分配的为 0(初始值)
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        # 分配的标签		(num_priors)
        assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
        # 得到分配的类别 根据 gt 的索引确认对应的类别 	(num_priors)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        
        # 最大 IoU 	(num_priors)
        max_overlaps = assigned_gt_inds.new_full(
            (num_bboxes,), -INF, dtype=torch.float32
        )
        
        # 填入有效锚点对应的 IoU
        max_overlaps[valid_mask] = matched_pred_ious

        # 这里的判断默认情况下不会为 True
        if (
            self.ignore_iof_thr > 0					# 默认 -1 > 0
            and gt_bboxes_ignore is not None
            and gt_bboxes_ignore.numel() > 0
            and num_bboxes > 0
        ):
            ignore_overlaps = bbox_overlaps(
                valid_decoded_bbox, gt_bboxes_ignore, mode="iof"
            )
            ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
            ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
            assigned_gt_inds[ignore_idxs] = -1
		
        # 返回 num_gts, 锚点 priors 分配的 gt 索引以及对应的 IoU, 匹配的 gt 对应的标签
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
        )  
# 根据预测框与真实框之间 IoU 以及损失矩阵来进行匹配
    def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
        """Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX.

        Args:
            cost (Tensor): Cost matrix.
            pairwise_ious (Tensor): Pairwise iou matrix.
            num_gt (int): Number of gt.
            valid_mask (Tensor): Mask for valid bboxes.
        """
        # 初始化一个与 cost 同形状的匹配矩阵 (num_valid, num_gts)
        matching_matrix = torch.zeros_like(cost)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.topk, pairwise_ious.size(0))
        # 选取每个真实框的前 topk 个最高 IoU 值     (candidate_topk, num_gt)
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        
        # calculate dynamic k for each gt
        # 计算每个 gt 的包含的锚点 points 的 IoU 最高的前 topk 个值的和, 作为动态 k
        # 即这里的 k 会根据前 topk 个 iou的值变化   (num_gts)
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
        # 进行动态匹配, 遍历 gt_bboxes
        # 对于每个 gt 挑选其中的 dymamic_k 个锚点进行匹配
        for gt_idx in range(num_gt):
            # 选取当前 gt_bbox 对应的损失矩阵中前 k 个最小损失值的索引, 即对应的锚点的索引
            # cost 维度为	 (num_priors, num_gts)
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
            )
            # 将对应的匹配矩阵中的值置为 1, 对应的 dynamic_k 个锚点和该 gt 匹配 
            matching_matrix[:, gt_idx][pos_idx] = 1.0

        del topk_ious, dynamic_ks, pos_idx

        # matching_matrix 尺寸为 (num_priors, num_gt), 挑选出匹配的锚点    (num_priors)
        # 一个锚点与两个或更多个 gt_bbox 匹配	(num_priors)
        prior_match_gt_mask = matching_matrix.sum(1) > 1

        # 如果存在一个锚点和多个 gt_bboxes 匹配, 那么则选择代价最小的那一个
        if prior_match_gt_mask.sum() > 0:
            # 对于匹配多个 gt_bbox 的锚点, 选择代价最小的 gt_bbox 进行匹配		(num_priors) 
            cost_min, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
            # 将匹配多个 gt_bbox 的锚点的匹配清空, 选择代价最小的那一个 gt_bbox
            matching_matrix[prior_match_gt_mask, :] *= 0.0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0

        # 匹配了 gt_bbox 的锚点矩阵  (num_priors) 	
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0.0
        # 更新有效 mask, valid_mask 表示的为匹配了 gt_bbox 的锚点 (num_priors)
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        # 获取有效匹配的每个预测框对应的 gt_bbox 的索引 maching_matrix     (num_priors, num_gts)
        # argmax 获取最大的那一个值的索引      (num_valid_priros, num_gts) -> (num_valid_priors)
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        # 计算有效匹配的每个预测框与 gt_bbox 之间的 IoU    
        # (num_valid, num_gts) -> (num_valid) -> (num_valid_priors)
        matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[fg_mask_inboxes]
        # 每个 prior 与一个 gt_bbox 对应 (一个 gt_bbox 可以对应多个预测框)
        return matched_pred_ious, matched_gt_inds
  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值