YOLOv6 TaskAlignedAssigner: the task-aligned label assignment strategy

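This article walks through how the TaskAlignedAssigner module in YOLOv6 works, covering its key methods get_pos_mask, select_highest_overlaps, and get_targets, which support aligned learning of the classification (cls) and regression (reg) tasks. The module is implemented in PyTorch and involves positive-sample selection, IoU computation, and target assignment.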

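YOLOv6 is anchor-free and follows TOOD in using task alignment learning to align the cls and reg tasks. The TaskAlignedAssigner code is shown below. It is built around three methods: self.get_pos_mask, select_highest_overlaps, and self.get_targets. It returns target_labels (the label assigned to each anchor point, i.e. each grid cell), target_bboxes (the ground-truth box assigned to each cell; cells without a match fall back to the first GT box and are filtered out later with the mask), target_scores (the per-class score for each cell), and fg_mask.bool() (the positive-sample mask).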

import torch
import torch.nn as nn
import torch.nn.functional as F

class TaskAlignedAssigner(nn.Module):
    def __init__(self,
                 topk=13,
                 num_classes=80,
                 alpha=1.0,
                 beta=6.0, 
                 eps=1e-9):
        super(TaskAlignedAssigner, self).__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self,
                pd_scores,
                pd_bboxes,
                anc_points,
                gt_labels,
                gt_bboxes,
                mask_gt):
        """This code referenced to
           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
        """
        self.bs = pd_scores.size(0)
        self.n_max_boxes = gt_bboxes.size(1)

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), \
                   torch.zeros_like(pd_bboxes).to(device), \
                   torch.zeros_like(pd_scores).to(device), \
                   torch.zeros_like(pd_scores[..., 0]).to(device)

        cycle, step, self.bs = (1, self.bs, self.bs) if self.n_max_boxes <= 100 else (self.bs, 1, 1)
        target_labels_lst, target_bboxes_lst, target_scores_lst, fg_mask_lst = [], [], [], []
        # loop batch dim in case of numerous object box
        for i in range(cycle):
            start, end = i*step, (i+1)*step
            pd_scores_ = pd_scores[start:end, ...]
            pd_bboxes_ = pd_bboxes[start:end, ...]
            gt_labels_ = gt_labels[start:end, ...]
            gt_bboxes_ = gt_bboxes[start:end, ...]
            mask_gt_   = mask_gt[start:end, ...]

            # Get the positive-sample mask, the alignment metric, and the IoU between GT and predicted boxes.
            mask_pos, align_metric, overlaps = self.get_pos_mask(
                pd_scores_, pd_bboxes_, gt_labels_, gt_bboxes_, anc_points, mask_gt_)

            # GT boxes can overlap, so one anchor point may be positive for several GTs;
            # keep only the GT with the highest IoU.
            target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
                mask_pos, overlaps, self.n_max_boxes)

            # assigned target
            target_labels, target_bboxes, target_scores = self.get_targets(
                gt_labels_, gt_bboxes_, target_gt_idx, fg_mask)

            # normalize
            align_metric *= mask_pos
            pos_align_metrics = align_metric.max(axis=-1, keepdim=True)[0]
            pos_overlaps = (overlaps * mask_pos).max(axis=-1, keepdim=True)[0]
            norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
            target_scores = target_scores * norm_align_metric

            # append
            target_labels_lst.append(target_labels)
            target_bboxes_lst.append(target_bboxes)
            target_scores_lst.append(target_scores)
            fg_mask_lst.append(fg_mask)

        # concat
        target_labels = torch.cat(target_labels_lst, 0)
        target_bboxes = torch.cat(target_bboxes_lst, 0)
        target_scores = torch.cat(target_scores_lst, 0)
        fg_mask = torch.cat(fg_mask_lst, 0)

        return target_labels, target_bboxes, target_scores, fg_mask.bool()
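
The "normalize" step in forward is easy to misread, so here is a minimal sketch of it for a single image, written in plain Python lists rather than PyTorch tensors so that it runs without dependencies. The function name normalize_align_metric and the input values are made up for illustration. For each GT, the alignment metric of its positives is rescaled so that the best positive receives that GT's largest IoU, and the GT dimension is then collapsed with a max.

```python
def normalize_align_metric(align_metric, overlaps, mask_pos, eps=1e-9):
    """Per-anchor normalized alignment for one image.

    All three inputs are nested lists of shape (n_gt, n_anchors).
    """
    n_gt, n_anchors = len(align_metric), len(align_metric[0])
    # Keep the metric only where the anchor is a positive for that GT.
    masked = [[align_metric[g][a] * mask_pos[g][a] for a in range(n_anchors)]
              for g in range(n_gt)]
    # Largest metric and largest IoU among each GT's positives.
    max_metric = [max(row) for row in masked]
    max_iou = [max(overlaps[g][a] * mask_pos[g][a] for a in range(n_anchors))
               for g in range(n_gt)]
    # Rescale per GT, then take the max over the GT dimension.
    return [max(masked[g][a] * max_iou[g] / (max_metric[g] + eps)
                for g in range(n_gt))
            for a in range(n_anchors)]


# Hypothetical values: 1 GT box, 3 anchors, the third anchor is not a positive.
align_metric = [[0.02, 0.08, 0.04]]
overlaps = [[0.5, 0.8, 0.6]]
mask_pos = [[1, 1, 0]]
norm = normalize_align_metric(align_metric, overlaps, mask_pos)
```

Here the best positive ends up with a value of about 0.8 (the GT's largest IoU), the weaker positive with about 0.2, and the non-positive anchor with 0. In forward this value multiplies the one-hot target_scores, which yields soft classification targets.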

1. self.get_pos_mask

        Computes three values: the positive-sample mask, the alignment metric, and the IoU between the GT boxes and the predicted boxes.

    def get_pos_mask(self,
                     pd_scores,
                     pd_bboxes,
                     gt_labels,
                     gt_bboxes,
                     anc_points,
                     mask_gt):

        # get anchor_align metric
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)  # alignment metric and IoU between GT and predicted boxes
        # get in_gts mask: anchor points whose center lies inside a GT box
        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
        # get topk_metric mask: the topk (13) anchors with the highest metric for each GT
        mask_topk = self.select_topk_candidates(
            align_metric * mask_in_gts, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
        # merge all mask to a final mask
        mask_pos = mask_topk * mask_in_gts * mask_gt  # an anchor is positive only if it satisfies all three conditions

        return mask_pos, align_metric, overlaps

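1.1 self.get_box_metrics: get the predicted scores and the IoU between GT and predicted boxes

         Computes bbox_scores and overlaps (IoU), then combines the classification score s and the IoU u into the alignment metric t = s^alpha * u^beta, which is returned together with the IoU.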

    def get_box_metrics(self,
                        pd_scores,
                        pd_bboxes,
                        gt_labels,
                        gt_bboxes):

        pd_scores = pd_scores.permute(0, 2, 1)  # (bs, num_classes, n_anchors), e.g. (32, 3, 8400)
        gt_labels = gt_labels.to(torch.long)  # (bs, n_max_boxes, 1), e.g. (32, 22, 1)
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # (2, 32, 22)
        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # batch index of each GT
        ind[1] = gt_labels.squeeze(-1)  # class index of each GT
        bbox_scores = pd_scores[ind[0], ind[1]]  # predicted score for each GT's class: (32, 3, 8400) -> (32, 22, 8400)

        overlaps = iou_calculator(gt_bboxes, pd_bboxes)  # IoU between GT and predicted boxes, (32, 22, 8400)
        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)  # alignment metric: score^alpha * IoU^beta

        return align_metric, overlaps
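
To get a feel for the alignment metric, the following sketch evaluates score^alpha * IoU^beta for two hypothetical anchors using the default alpha=1.0 and beta=6.0 from __init__. It is plain Python on scalars, not the tensor code above, and the input numbers are invented.

```python
def align_metric(score, iou, alpha=1.0, beta=6.0):
    """Task-alignment metric for one (anchor, GT) pair: score^alpha * iou^beta."""
    return (score ** alpha) * (iou ** beta)


# Confident class score but poorly localized box.
a = align_metric(score=0.9, iou=0.5)   # 0.9 * 0.5**6 = 0.0140625
# Modest class score but well localized box.
b = align_metric(score=0.5, iou=0.9)   # 0.5 * 0.9**6 = 0.2657205
```

With beta=6.0 the IoU term dominates: the well-localized anchor scores roughly 19 times higher even though its class score is lower.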

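1.2 select_candidates_in_gts: find the anchor points that lie inside a GT box and return the mask

        This subtracts each anchor point's center from the GT box's bottom-right corner, and subtracts the GT box's top-left corner from the center. If all four resulting values are greater than 0, the center lies inside the GT box.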

def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """select the positive anchors's center in gt

    Args:
        xy_centers (Tensor): shape(num_total_anchors, 2)
        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
    Return:
        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    n_anchors = xy_centers.size(0)
    bs, n_max_boxes, _ = gt_bboxes.size()
    _gt_bboxes = gt_bboxes.reshape([-1, 4])  # e.g. (32, 22, 4) -> (704, 4)
    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)  # (704, 8400, 2): one copy of the centers per GT box, for comparison with the GT corners
    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)  # top-left corner repeated 8400 times
    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)  # bottom-right corner repeated 8400 times
    b_lt = xy_centers - gt_bboxes_lt  # center - top-left, positive if inside
    b_rb = gt_bboxes_rb - xy_centers  # bottom-right - center, positive if inside
    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)  # all four deltas positive: the anchor point's center lies inside the GT box
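
The inside-the-box test can be checked by hand on a few points. The sketch below is a plain-Python rendering of the same rule for a single GT box; the helper name centers_in_box and the coordinates are illustrative assumptions, not part of YOLOv6.

```python
def centers_in_box(centers, box, eps=1e-9):
    """Return a 0/1 mask telling which (cx, cy) centers lie strictly inside box.

    box is (x1, y1, x2, y2) with (x1, y1) the top-left corner.
    """
    x1, y1, x2, y2 = box
    mask = []
    for cx, cy in centers:
        # center - top-left and bottom-right - center, as in the tensor code
        deltas = (cx - x1, cy - y1, x2 - cx, y2 - cy)
        mask.append(1.0 if min(deltas) > eps else 0.0)
    return mask


centers = [(4.0, 4.0), (12.0, 4.0), (20.0, 20.0), (8.0, 8.0)]
box = (8.0, 0.0, 16.0, 8.0)
mask = centers_in_box(centers, box)  # [0.0, 1.0, 0.0, 0.0]
```

Only the second center is inside. The last one sits exactly on a corner of the box, so its smallest delta is 0 and the strict comparison against eps rejects it.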

1.3 self.select_topk_candidates: mask of the 13 highest-scoring anchors for each GT, i.e. where those 13 sit among the 8400 anchor points

    def select_topk_candidates(self,
                               metrics,
                               largest=True,
                               topk_mask=None):

        num_anchors = metrics.shape[-1]  # 8400
        # Take the topk (13) highest metrics per GT and their anchor indices: (32, 22, 13)
        topk_metrics, topk_idxs = torch.topk(
            metrics, self.topk, axis=-1, largest=largest)
        if topk_mask is None:
            # max() returns (values, indices); compare the values against eps
            topk_mask = (topk_metrics.max(axis=-1, keepdim=True)[0] > self.eps).tile(
                [1, 1, self.topk])
        topk_idxs = torch.where(topk_mask, topk_idxs, torch.zeros_like(topk_idxs))  # padded (invalid) GT rows get index 0
        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)  # one-hot over anchors, summed over the topk dim
        # a padded row points at anchor 0 topk times, so counts > 1 are cleared
        is_in_topk = torch.where(is_in_topk > 1,
            torch.zeros_like(is_in_topk), is_in_topk)
        return is_in_topk.to(metrics.dtype)
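
The trick with zeroed indices and the "count > 1" cleanup is easier to see on a tiny case. This is a rough plain-Python sketch for one GT row with k=3 instead of 13; the helper topk_mask and the flag gt_is_valid are hypothetical names standing in for the topk_mask tensor built from mask_gt. Tie-breaking may differ from torch.topk.

```python
def topk_mask(metrics, k, gt_is_valid=True):
    """0/1 mask over anchors marking the k highest metrics for one GT row."""
    order = sorted(range(len(metrics)), key=lambda i: metrics[i], reverse=True)
    # A padded GT row has all of its k indices replaced by 0.
    idxs = order[:k] if gt_is_valid else [0] * k
    # Equivalent of one_hot(...).sum over the topk dimension.
    counts = [0] * len(metrics)
    for i in idxs:
        counts[i] += 1
    # Anchor 0 of a padded row is counted k times; such counts are cleared.
    return [0 if c > 1 else c for c in counts]


metrics = [0.0, 0.30, 0.05, 0.0, 0.12, 0.01]
valid_mask = topk_mask(metrics, k=3)                       # [0, 1, 1, 0, 1, 0]
padded_mask = topk_mask(metrics, k=3, gt_is_valid=False)   # all zeros
```

For a real GT the three best anchors are marked. For a padded GT every index collapses onto anchor 0, the count there becomes 3, and the cleanup removes it, so the row contributes no positives.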

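2. select_highest_overlaps

        At this point we have the positive-sample mask, but GT boxes can overlap, so one anchor point may be matched to several GT boxes. First the per-GT masks are summed over the GT dimension, (32, 22, 8400) -> (32, 8400), to check for anchors claimed more than once. For each such anchor, the GT box with the largest IoU is kept as its target and mask_pos is updated. The function returns mask_pos, fg_mask, and target_gt_idx (shape 32x8400: for each of the 8400 anchor points, the index of its assigned GT box).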

def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """if an anchor box is assigned to multiple gts,
        the one with the highest iou will be selected.

    Args:
        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    Return:
        target_gt_idx (Tensor): shape(bs, num_total_anchors)
        fg_mask (Tensor): shape(bs, num_total_anchors)
        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    fg_mask = mask_pos.sum(axis=-2)  # merge the per-GT masks of each image into one mask: (32, 22, 8400) -> (32, 8400)
    if fg_mask.max() > 1:
        # True where an anchor is claimed by more than one GT, repeated once per GT box
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
        # for each of the 8400 anchors, the index of the GT box with the largest IoU: (32, 8400)
        max_overlaps_idx = overlaps.argmax(axis=1)
        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)  # re-encode that choice as a one-hot mask
        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)  # (32, 8400, 22) -> (32, 22, 8400)
        # use is_max_overlaps where several GTs claim the anchor, otherwise keep mask_pos
        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
        fg_mask = mask_pos.sum(axis=-2)
    target_gt_idx = mask_pos.argmax(axis=-2)
    return target_gt_idx, fg_mask, mask_pos
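
The conflict-resolution rule can be sketched for one image with plain Python lists. The function name resolve_multi_gt and the numbers are illustrative assumptions; the logic mirrors the tensor version: where an anchor is claimed by several GT boxes, keep only the GT with the largest IoU.

```python
def resolve_multi_gt(mask_pos, overlaps):
    """mask_pos, overlaps: nested lists of shape (n_gt, n_anchors) for one image."""
    n_gt, n_anchors = len(mask_pos), len(mask_pos[0])
    mask = [row[:] for row in mask_pos]  # do not modify the caller's lists
    for a in range(n_anchors):
        if sum(mask[g][a] for g in range(n_gt)) > 1:
            # Anchor claimed by several GTs: keep the one with the highest IoU.
            best = max(range(n_gt), key=lambda g: overlaps[g][a])
            for g in range(n_gt):
                mask[g][a] = 1 if g == best else 0
    fg_mask = [sum(mask[g][a] for g in range(n_gt)) for a in range(n_anchors)]
    # argmax over the GT dimension; background anchors fall back to GT 0.
    target_gt_idx = [max(range(n_gt), key=lambda g: mask[g][a])
                     for a in range(n_anchors)]
    return target_gt_idx, fg_mask, mask


mask_pos = [[1, 1, 0, 0],
            [0, 1, 1, 0]]          # anchor 1 is claimed by both GTs
overlaps = [[0.7, 0.4, 0.1, 0.0],
            [0.2, 0.6, 0.5, 0.0]]
target_gt_idx, fg_mask, new_mask = resolve_multi_gt(mask_pos, overlaps)
```

Anchor 1 goes to GT 1 because 0.6 > 0.4. The last anchor is background: its fg_mask entry is 0, yet target_gt_idx still reports 0, which is why unmatched anchors temporarily carry the first GT box until fg_mask filters them out.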

3. self.get_targets

        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes 

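batch_ind * self.n_max_boxes is the flat index of each image's first GT box: 0 for the first image, 22 for the second (22 being the maximum number of GT boxes per image in this example), 44 for the third, and so on. Adding it to the original target_gt_idx (the index of the GT box assigned to each anchor point) turns a per-image index into an index over the whole batch. For example, the first row becomes 0,0,0,0,0,2,...,5,0,0,0,1,...,0,0,0 and the second row becomes 22,22,22,22,23,22,22,...,28,22,22,22,22. Each value therefore identifies which GT box of which image the anchor point is assigned to.

[Figure: target_gt_idx after adding the batch offset]

After gt_labels is flattened (32 x 22 = 704 entries; 682 is the offset of the last image, and 682 + 22 = 704), its positions line up exactly with the target_gt_idx values above. Indexing with them yields the class and the GT box assigned to each of the 8400 anchor points.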
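
The offset arithmetic can be reproduced with a two-image toy batch. The sketch below uses plain Python lists; the helper flat_gt_index and the label values (image index times 100 plus the GT index) are made up purely so that the lookup result is easy to read.

```python
def flat_gt_index(target_gt_idx, n_max_boxes):
    """Add b * n_max_boxes to every index in row b (b = image index)."""
    return [[idx + b * n_max_boxes for idx in row]
            for b, row in enumerate(target_gt_idx)]


n_max_boxes = 22
# Per-image GT indices for 2 images with 4 anchors each (hypothetical).
target_gt_idx = [[0, 0, 2, 5],
                 [0, 1, 0, 6]]
flat = flat_gt_index(target_gt_idx, n_max_boxes)

# Hypothetical labels: image b, GT g has label b * 100 + g.
gt_labels = [[b * 100 + g for g in range(n_max_boxes)] for b in range(2)]
flat_labels = [label for row in gt_labels for label in row]  # like .flatten()
target_labels = [[flat_labels[i] for i in row] for row in flat]
```

The second row becomes 22, 23, 22, 28, and looking those positions up in the flattened labels returns 100, 101, 100, 106, i.e. GT boxes 0, 1, 0, 6 of the second image.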

 

    def get_targets(self,
                    gt_labels,
                    gt_bboxes,
                    target_gt_idx,
                    fg_mask):

        # assigned target labels
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[...,None]  # (32, 1): 0, 1, 2, ..., 31
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (32, 8400); batch_ind * self.n_max_boxes is the flat index of each image's first GT box: 0 for the first image, 22 for the second
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # class assigned to each of the 8400 anchor points

        # assigned target boxes
        target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx]  # GT box assigned to each of the 8400 anchor points

        # assigned target scores
        target_labels[target_labels<0] = 0
        target_scores = F.one_hot(target_labels, self.num_classes)  # one-hot encoding of the assigned class
        fg_scores_mask  = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
        target_scores = torch.where(fg_scores_mask > 0, target_scores,
                                        torch.full_like(target_scores, 0))

        return target_labels, target_bboxes, target_scores
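
As a final check on the last step of get_targets, the sketch below builds the one-hot scores for three anchors of one image and zeroes the rows of background anchors. It is plain Python, and the helper name one_hot_scores and its inputs are illustrative assumptions.

```python
def one_hot_scores(target_labels, fg_mask, num_classes):
    """One-hot class rows, kept only where fg_mask marks a positive anchor."""
    scores = []
    for label, fg in zip(target_labels, fg_mask):
        row = [0] * num_classes
        if fg > 0:
            row[label] = 1
        scores.append(row)
    return scores


# Three anchors, three classes; the middle anchor is background.
scores = one_hot_scores(target_labels=[2, 0, 1], fg_mask=[1, 0, 1], num_classes=3)
```

The background anchor carries label 0 only as a placeholder, and its score row is all zeros after masking, so it contributes no positive classification target.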
