YOLOX源码解读---自有数据集消融实验

本文详细解读YOLOX网络结构及其改进方法,包括SimOTA标签分配策略、multi positives、anchor-free、decoupled head和strong augmentation。通过消融实验,探讨这些组件在YOLOX中的作用,特别是在SSDD数据集上的训练效果。
摘要由CSDN通过智能技术生成

YOLOX网络结构:
在这里插入图片描述

YOLOX相比于YOLOv3的主要改进方法
在这里插入图片描述

YOLOX从零开始训练SSDD数据集的结果:
TODO

采用与YOLOX相反的顺序做消融实验,直到回归到yolov3 baseline的水平

SimOTA

 # 设置候选框的默认数量为10,如果初步筛选的数量小于10,则设定为初步筛选出的框的数量
        n_candidate_k = min(1, ious_in_boxes_matrix.size(1)) 
        # 从前面的pair_wise_ious中,给每个目标框,挑选10个iou最大的候选框。
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1) # [3, 10]

主要解决标签分配的问题。
在YOLOX中正样本锚框的挑选分为两个步骤:初步筛选 + SimOTA

初步筛选:

  1. 根据anchor中心点判断:即anchor的中心点位于gt_boxes内,则满足正样本锚框的条件
  2. 根据gt_box中心点判断:以gt_boxes的中心点为中心,生成一个边长为5的正方形,
    如果anchor的中心点位于正方形内,则满足正样本锚框的条件
  3. 实际上返回的fg_mask是is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all,即满足其中一个正样本的条件即可
    def get_in_boxes_info(
        self,
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
        total_num_anchors,
        num_gt,
    ):
        # 1.根据中心点判断:即anchor的中心点位于gt_boxes内,则满足正样本锚框的条件
        expanded_strides_per_image = expanded_strides[0]
        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
        x_centers_per_image = (
            (x_shifts_per_image + 0.5 * expanded_strides_per_image)
            .unsqueeze(0)
            .repeat(num_gt, 1)
        )  # [n_anchor] -> [n_gt, n_anchor]
        y_centers_per_image = (
            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
            .unsqueeze(0)
            .repeat(num_gt, 1)
        )

        gt_bboxes_per_image_l = (
            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )
        gt_bboxes_per_image_r = (
            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )
        gt_bboxes_per_image_t = (
            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )
        gt_bboxes_per_image_b = (
            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
            .unsqueeze(1)
            .repeat(1, total_num_anchors)
        )

        b_l = x_centers_per_image - gt_bboxes_per_image_l
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)

        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
        # in fixed center

        # 2.根据目标框判断:以gt_boxes的中心点为中心,生成一个边长为5的正方形,
        #   如果anchor的中心点位于正方形内,则满足正样本锚框的条件
        center_radius = 2.5

        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) + center_radius * expanded_strides_per_image.unsqueeze(0)

        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0

        # in boxes and in centers
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all

        is_in_boxes_and_center = (
            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        )
        # 3. 实际上返回的fg_mask是is_in_boxes_anchor,即满足其中一个正样本的条件即可
        return is_in_boxes_anchor, is_in_boxes_and_center

SimOTA
主要分成四个步骤:

  1. 提取出初步筛选的正样本锚框
  2. 计算初步筛选的正样本锚框与gt_boxes的pair_wise_ious_loss
  3. 计算cost成本:对cls_preds、reg_loss(pair_wise_ious_loss)、is_in_boxes_and_center(同时满足初步筛选的两个条件)进行加权求和
  4. 进行动态标签分配
        # SimOTA
        # 1. 提取出初步筛选的正样本锚框
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
        cls_preds_ = cls_preds[batch_idx][fg_mask]
        obj_preds_ = obj_preds[batch_idx][fg_mask]
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

        if mode == "cpu":
            gt_bboxes_per_image = gt_bboxes_per_image.cpu()
            bboxes_preds_per_image = bboxes_preds_per_image.cpu()

        # 2. 计算初步筛选的正样本锚框与gt_boxes的pair_wise_ious_loss
        pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)

        gt_cls_per_image = (
            F.one_hot(gt_classes.to(torch.int64), self.num_classes)
            .float()
            .unsqueeze(1)
            .repeat(1, num_in_boxes_anchor, 1)
        )
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

        if mode == "cpu":
            cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

        # 3. 计算cost成本:对cls_preds、reg_loss(pair_wise_ious_loss)、is_in_boxes_and_center(同时满足初步筛选的两个条件)进行加权求和
        with torch.cuda.amp.autocast(enabled=False):
            cls_preds_ = (
                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            )
            pair_wise_cls_loss = F.binary_cross_entropy(
                cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
            ).sum(-1)
        del cls_preds_

        cost = (
            pair_wise_cls_loss
            + 3.0 * pair_wise_ious_loss
            + 100000.0 * (~is_in_boxes_and_center)
        )

        # 4. 进行动态标签分配
        (
            num_fg,
            gt_matched_classes,
            pred_ious_this_matching,
            matched_gt_inds,
        ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

其中动态标签分配是SimOTA的核心,分配过程如下:

第一步:设置候选框数量
假设初步筛选出1000个正样本anchor,有3个gt_boxes

        # 1. 设置候选框的数量

        # 创建全为零的初始矩阵,用来保存最终的分配结果
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) # [3, 1000]

        ious_in_boxes_matrix = pair_wise_ious # [3, 1000]
        # 设置候选框的默认数量为10,如果初步筛选的数量小于10,则设定为初步筛选出的框的数量
        n_candidate_k = min(10, ious_in_boxes_matrix.size(1)) 
        # 从前面的pair_wise_ious中,给每个目标框,挑选10个iou最大的候选框。
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1) # [3, 10]

第二步:通过cost挑选候选框

        # 2. 通过cost挑选候选框

        # 通过对topk_ious进行求和,得到gt_boxes对应的正样本anchor的动态数量dynamic_ks
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        dynamic_ks = dynamic_ks.tolist()
        # 为每个gt_boxes挑选出cost值最低dynamic_ks个候选框
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
            )
            matching_matrix[gt_idx][pos_idx] = 1

        del topk_ious, dynamic_ks, pos_idx

通过对topk_ious进行求和,得到gt_boxes对应的正样本anchor的动态数量dynamic_ks的过程图示:
在这里插入图片描述
为每个gt_boxes挑选出cost值最低dynamic_ks个候选框图示
在这里插入图片描述
得到的matching_matrix中,cost值最低的一些位置,数值为1,其余位置都为0。
目标框1和目标框3,dynamic_ks值都为3,因此matching_matrix的第一行和第三行,有3个1。
目标框2,dynamic_ks值为4,因此matching_matrix的第二行,有4个1。

第三步:过滤共用的候选框
如上图的matching_matrix中,第5列有两个1。
这也就说明,第五列所对应的候选框,被目标检测框1和2,都进行关联。
因此对这两个位置,还要使用cost值进行对比,选择较小的值,再进一步筛选。

  # 3. 过滤共用的候选框

        # 对matching_matrix每一列进行相加,得到一个anchor所匹配的gt_boxes数量
        anchor_matching_gt = matching_matrix.sum(0)
        # 更小的cost保持为1,更大的cost置为0
        if (anchor_matching_gt > 1).sum() > 0:
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1

具体流程如下:
对matching_matrix每一列进行相加,得到一个anchor所匹配的gt_boxes数量
在这里插入图片描述
更小的cost保持为1,更大的cost置为0
在这里插入图片描述

在这里插入图片描述

消融实验的实现实际上将候选框的默认数量设置为1即可

n_candidate_k = min(1, ious_in_boxes_matrix.size(1)) 

multi positives

        center_radius = .6 # single positive
        # center_radius = 2.5 # multi positives

实际上就是在标签分配的初步筛选过程中的根据gt_box中心点判断:以gt_boxes的中心点为中心,生成一个边长为5的正方形, 如果anchor的中心点位于正方形内,则满足正样本锚框的条件
如果是single positive,可以理解只由gt_boxes中心所在的pixel产生的anchor进行回归
而multi positives,则是采用gt_boxes中心周围的pixels产生的anchors进行回归
这样做的目的是,增加了正样本anchor的数量

因此multi positives的消融实验,可以通过将正方形的边长设为1实现

anchor-free

因为最终目的是回归旋转框,因此暂时不做此消融实验,以后再填坑

decoupled head

在这里插入图片描述
解耦头的意思是对cls_output、obj_output、reg_output进行解耦,用三个分支输出
yolox中的实现方式:使用 1个1x1 的卷积先进行降维,并在后面两个分支里,各使用了 2个3x3 卷积,最终调整到仅仅增加一点点的网络参数。
在这里插入图片描述

在yolo_head.py中decoupled head实现方式如下:
在这里插入图片描述
用于降维的1*1卷积为:
在这里插入图片描述
分类解耦分支中的两个3 * 3的卷积:
在这里插入图片描述

回归解耦分支中的两个3 * 3的卷积
在这里插入图片描述

因此去掉decoupled head的实现方式为:
为了保持后面网络的维度统一,保留1*1的降维卷积

            x = self.stems[k](x) # 1 * 1的降维卷积
          
            cls_output = self.cls_preds[k](x) 
            reg_output = self.reg_preds[k](x) 
            obj_output = self.obj_preds[k](x)

参考文章:YOLOX解读

strong augmentation

在yolo_base.py中,修改self.no_aug_epochs即可

        # last #epoch to close augmention like mosaic
        # self.no_aug_epochs = 15
        self.no_aug_epochs = 120

相比于YOLOv3,YOLOX的strong augmentation主要是使用了Mosaic、Mixup两种数据增强方法
值得注意的是,由于采取了更强的数据增强方式,作者在研究中发现,ImageNet预训练将毫无意义,因此,所有的模型,均是从头开始训练的。

loss function

 loss_iou = (
            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
        ).sum() / num_fg
 loss_obj = (
            self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
        ).sum() / num_fg
 loss_cls = (
            self.bcewithlog_loss(
                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
            )
        ).sum() / num_fg

检测框的loss_iou :iou_loss vs. giou_loss,默认是iou_loss
值得注意的是,iou_loss和cls_loss,只针对gt_boxes与初步筛选得到的正样本anchor进行计算;而obj_loss,是对所有的anchor计算loss
因此,可以尝试将loss_obj的损失函数从bcewithlog_loss替换成focal loss

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值