YOLOX算法笔记

本文为个人学习过程中所记录笔记,便于梳理思路和后续查看用,如有错误,感谢批评指正!

code:https://github.com/Megvii-BaseDetection/YOLOX
paper:https://arxiv.org/abs/2107.08430

参考:
【1】从零开始实现yolox四:模型的训练(一)损失函数与标签分配
1、BACKBONE
默认采用yolov3的darknet53作为backbone,其中采用backbone的三个不同分辨率输出的分支C3, C4, C5.
2、NECK
yolox中采用了PAFPN作为neck,其与FPN相比结构稍为复杂,但总体思想还是多尺度特征的融合。整体结构草图如下:
在这里插入图片描述
3、HEAD
yolox中采用解耦头,将回归和分类分支进行解耦。具体结构如图所示:
在这里插入图片描述

4、数据集增强
4.1 mosaic
数据集增强顺序为mosaic+mixup,即1组mosaic图片和一张常规图片进行一次mixup。

训练代码分析
模型中采用解耦头,针对三种不同分辨率,每一种分辨率输出特征图均将分类和回归解耦,一共有box头,回归头,置信度头(暂且这样称呼吧)。然后将三个头concat,最后将三种尺度的进一步concat。

if self.training:
                output = torch.cat([reg_output, obj_output, cls_output], 1)
                output, grid = self.get_output_and_grid(
                    output, k, stride_this_level, xin[0].type()
                )
  #输出的处理                          
def get_output_and_grid(self, output, k, stride, dtype):
       grid = self.grids[k]  #[0, 0, 0]
       batch_size = output.shape[0]
       n_ch = 5 + self.num_classes
       hsize, wsize = output.shape[-2:] #h,w
       if grid.shape[2:4] != output.shape[2:4]: #进入分支
           yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])  #生成网格,类似于位置编码
           grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype) #1*1*80*80*2
           self.grids[k] = grid ##1*1*80*80*2

       output = output.view(batch_size, 1, n_ch, hsize, wsize)
       output = output.permute(0, 1, 3, 4, 2).reshape(
           batch_size, hsize * wsize, -1
       ) #batch * 6400 * n_ch
       grid = grid.view(1, -1, 2) #1,80*80,2
       output[..., :2] = (output[..., :2] + grid) * stride   #前两列box的左上角坐标加入grid乘以stride
       output[..., 2:4] = torch.exp(output[..., 2:4]) * stride  #2,3列box右下角坐标归一化然后乘以stride
       return output, grid

loss计算:

def get_losses(
        self,
        imgs,
        x_shifts,
        y_shifts,
        expanded_strides,
        labels,
        outputs,
        origin_preds,
        dtype,
    ):
        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]

        # calculate targets
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

        total_num_anchors = outputs.shape[1]
        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
        expanded_strides = torch.cat(expanded_strides, 1)
        if self.use_l1:
            origin_preds = torch.cat(origin_preds, 1)

        cls_targets = []
        reg_targets = []
        l1_targets = []
        obj_targets = []
        fg_masks = []

        num_fg = 0.0
        num_gts = 0.0
        for batch_idx in range(outputs.shape[0]):
            num_gt = int(nlabel[batch_idx]) #当前图片中gt数目
            num_gts += num_gt
            if num_gt == 0: #如果当前图片gt数目为0,就新建几个空张量
                cls_target = outputs.new_zeros((0, self.num_classes)) #类别
                reg_target = outputs.new_zeros((0, 4))  #回归框
                l1_target = outputs.new_zeros((0, 4))   #
                obj_target = outputs.new_zeros((total_num_anchors, 1)) #置信度
                fg_mask = outputs.new_zeros(total_num_anchors).bool() #能和gt匹配的预测框的索引,这里则全为false
            else:
                gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5] #真值框坐标
                gt_classes = labels[batch_idx, :num_gt, 0]  #真值类别
                bboxes_preds_per_image = bbox_preds[batch_idx]  #预测框

                try:
                    (
                        gt_matched_classes,
                        fg_mask,
                        pred_ious_this_matching,
                        matched_gt_inds,
                        num_fg_img,
                    ) = self.get_assignments(  # noqa
                        batch_idx,
                        num_gt,
                        gt_bboxes_per_image,
                        gt_classes,
                        bboxes_preds_per_image,
                        expanded_strides,
                        x_shifts,
                        y_shifts,
                        cls_preds,
                        obj_preds,
                    )
                except RuntimeError as e:
                    # TODO: the string might change, consider a better way
                    if "CUDA out of memory. " not in str(e):
                        raise  # RuntimeError might not caused by CUDA OOM

                    logger.error(
                        "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
                           CPU mode is applied in this batch. If you want to avoid this issue, \
                           try to reduce the batch size or image size."
                    )
                    torch.cuda.empty_cache() #释放显存,有文档说该句可以省略,解释为:'''当显存中的数据没有任何变量引用时,会自动释放显存,但释放的显存在Nvidia中看不到,只有加上这一句,才会在Nvidia-smi中释放'''                    
                  (	gt_matched_classes,
                      fg_mask,
                      pred_ious_this_matching,
                      matched_gt_inds,
                      num_fg_img,
                    ) = self.get_assignments(  # noqa
                        batch_idx,
                        num_gt,
                        gt_bboxes_per_image,
                        gt_classes,
                        bboxes_preds_per_image,
                        expanded_strides,
                        x_shifts,
                        y_shifts,
                        cls_preds,
                        obj_preds,
                        "cpu",
                    )

                torch.cuda.empty_cache()
                num_fg += num_fg_img

                cls_target = F.one_hot(
                    gt_matched_classes.to(torch.int64), self.num_classes
                ) * pred_ious_this_matching.unsqueeze(-1)
                obj_target = fg_mask.unsqueeze(-1)
                reg_target = gt_bboxes_per_image[matched_gt_inds]
                if self.use_l1:
                    l1_target = self.get_l1_target(
                        outputs.new_zeros((num_fg_img, 4)),
                        gt_bboxes_per_image[matched_gt_inds],
                        expanded_strides[0][fg_mask],
                        x_shifts=x_shifts[0][fg_mask],
                        y_shifts=y_shifts[0][fg_mask],
                    )

            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.to(dtype))
            fg_masks.append(fg_mask)
            if self.use_l1:
                l1_targets.append(l1_target)

        cls_targets = torch.cat(cls_targets, 0)
        reg_targets = torch.cat(reg_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        fg_masks = torch.cat(fg_masks, 0)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)

        num_fg = max(num_fg, 1)
        loss_iou = (
            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
        ).sum() / num_fg
        loss_obj = (
            self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
        ).sum() / num_fg
        # loss_obj = (
        #     self.focal_loss(obj_preds.sigmoid().view(-1, 1), obj_targets)
        # ).sum() / num_fg #cheng
        loss_cls = (
            self.bcewithlog_loss(
                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
            )
        ).sum() / num_fg #cheng 原本是self.bcewithlog_loss()
        if self.use_l1:
            loss_l1 = (
                self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
            ).sum() / num_fg
        else:
            loss_l1 = 0.0

        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

        return (
            loss,
            reg_weight * loss_iou,
            loss_obj,
            loss_cls,
            loss_l1,
            num_fg / max(num_gts, 1),
        )

标签分配:
具体步骤如下:
1、通过几何约束筛选第一遍anchor,具体函数见get_geometry_constraint,在每个gt中心画一个边长为3的正方形框,然后anchor中心在该正方形内的便是正样本。
2、将第一遍筛选得到的正样本继续采用simOTA算法进行筛选。见函数get_assignments和simota_matching。具体地,将第一遍筛选得到的正样本计算得到一个cost矩阵。按照iou排序,得到前十个框,然后将是个iou值相加取整,作为动态k个候选框(小于1的取值1)。同时对单个anchor对应多个gt的情况进行处理。得到最终动态分配正样本的结果。

@torch.no_grad()
def get_assignments(
        self,
        batch_idx,
        num_gt,
        gt_bboxes_per_image,
        gt_classes,
        bboxes_preds_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
        cls_preds,
        obj_preds,
        mode="gpu",
    ):

        if mode == "cpu":
            print("-----------Using CPU for the Current Batch-------------")
            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
            bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
            gt_classes = gt_classes.cpu().float()
            expanded_strides = expanded_strides.cpu().float()
            x_shifts = x_shifts.cpu()
            y_shifts = y_shifts.cpu()

        fg_mask, geometry_relation = self.get_geometry_constraint(
            gt_bboxes_per_image,
            expanded_strides,
            x_shifts,
            y_shifts,
        )  #几何约束,以gt中心为中心,边长为3的正方形,看预测框中心是否在这个正方形内,进行过滤。
		# bboxes_preds_per_image.shape,  [n_anchors_all, 4]
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
        cls_preds_ = cls_preds[batch_idx][fg_mask]
        obj_preds_ = obj_preds[batch_idx][fg_mask]
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

        if mode == "cpu":
            gt_bboxes_per_image = gt_bboxes_per_image.cpu()
            bboxes_preds_per_image = bboxes_preds_per_image.cpu()

        pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)

        gt_cls_per_image = (
            F.one_hot(gt_classes.to(torch.int64), self.num_classes)
            .float()
        )
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

        if mode == "cpu":
            cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

        with torch.cuda.amp.autocast(enabled=False):
            cls_preds_ = (
                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
            ).sqrt()
            pair_wise_cls_loss = F.binary_cross_entropy(
                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
                reduction="none"
            ).sum(-1)
        del cls_preds_

        cost = (
            pair_wise_cls_loss
            + 3.0 * pair_wise_ious_loss
            + float(1e6) * (~geometry_relation)
        ) #第一次得出的anchor计算分类损失和iou损失得到cost

        (
            num_fg,
            gt_matched_classes,
            pred_ious_this_matching,
            matched_gt_inds,
        ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

        if mode == "cpu":
            gt_matched_classes = gt_matched_classes.cuda()
            fg_mask = fg_mask.cuda()
            pred_ious_this_matching = pred_ious_this_matching.cuda()
            matched_gt_inds = matched_gt_inds.cuda()

        return (
            gt_matched_classes,
            fg_mask,
            pred_ious_this_matching,
            matched_gt_inds,
            num_fg,
        )

几何约束:

def get_geometry_constraint(
        self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
    ):
        """
        Calculate whether the center of an object is located in a fixed range of
        an anchor. This is used to avert inappropriate matching. It can also reduce
        the number of candidate anchors so that the GPU memory is saved.
        """
        expanded_strides_per_image = expanded_strides[0]
        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0) 
        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0) 
        #计算映射回原图的中心点的坐标,单位1加上0.5,
        #便是中心点,然后乘以对应的下采样倍数,乘号两边维度均为【8400】,结果为【1*8400】

        # in fixed center
        center_radius = 1.5
        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius #torch.Size([1, 8400])
        #gt_bboxes_per_image,gt的中心以及宽高
        #以gt为中心画一个边长为3的正方形, gt_bboxes_per_image格式为中心点坐标加上宽高
        # print(gt_bboxes_per_image.shape) #[num_gt, 4] 
        # print(gt_bboxes_per_image[:, 0:1].shape)   torch.Size([num_gt, 1])
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist  #[num_gt, num_anchor]
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist
		
        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) #[num_gt,num_anchor,4]
        is_in_centers = center_deltas.min(dim=-1).values > 0.0 #[num_gt, num_anchor]
        anchor_filter = is_in_centers.sum(dim=0) > 0 #剔除针对每个gt,gt中心值都不在框内的anchor,size为[num_anchor]
        geometry_relation = is_in_centers[:, anchor_filter]

        return anchor_filter, geometry_relation

simOTA算法:

def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        # Dynamic K
        # ---------------------------------------------------------------
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

        n_candidate_k = min(10, pair_wise_ious.size(1))
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) #找出最多前十个iou最匹配的anchor [:, 10]
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) #将输入input张量每个元素的范围限制到区间 [min,max],
        #返回结果到一个新张量;十个iou求和取整,即做为当前gt匹配得到的iou数量
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
            )
            matching_matrix[gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

        anchor_matching_gt = matching_matrix.sum(0)
        # deal with the case that one anchor matches multiple ground-truths
        if anchor_matching_gt.max() > 1: #处理一个anchor对应多个gt的情况
            multiple_match_mask = anchor_matching_gt > 1 #找出单个anchor对应多个gt的位置mask
            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0) #找出单个anchor对应多个gt的位置mask中的最小值mask
            matching_matrix[:, multiple_match_mask] *= 0 #单个anchor对应多个gt的位置全置为0
            matching_matrix[cost_argmin, multiple_match_mask] = 1 #重复对应位置中的cost最小值为1,其余为0
        fg_mask_inboxes = anchor_matching_gt > 0
        num_fg = fg_mask_inboxes.sum().item()

        fg_mask[fg_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        gt_matched_classes = gt_classes[matched_gt_inds]

        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
            fg_mask_inboxes
        ]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值