YOLOX的损失函数(SimOTA 正负样本分配策略)

本文详细描述了SimOTA算法中如何确定正负样本,包括基于真实框的anchorpoint处理、成本计算(包含IoU损失和置信度损失),以及最终的混淆矩阵构建过程。作者还提供了相关代码片段以解释这一过程。
摘要由CSDN通过智能技术生成

SimOTA 正负样本分配策略

请带着下面4个要点进行阅读
在这里插入图片描述

在这里插入图片描述

如下图 gt0,gt1,gt2为真实框,以真实框的中心为中心五倍的格子为边长(五倍是以 原始特征图为参照 )形成一个下图的蓝色的正方形。(代码中是放大到了真实大小)。
以下图为例有20x20=400个anchor point(即grid cell ) 每个anchor point会预测一组tx,ty,tw,th,obj 如上图在当前特征层中计算出 预测框的 中心点和w,h ,由于为anchor_free故算wh时不用再乘 anchor 模版的w h)。如果anchor point 中心点落在 gt和正方形形成并集区域内 那么这些 anchor point 就有可能成为正样本,是不是真正的正样本还需要进行筛选。
在这里插入图片描述
将中心点落在 gt和正方形形成并集的区域内 的anchor point 的预测值进行处理 ,计算出每个预测值与 gt 的cost,计算方式如下
在这里插入图片描述
pair_wise_cls_loss: 为anchor point 与gt的分类损失
pair_wise_iou :为anchor point 与gt的回归损失(位置损失)
~is_in_boxes_and_center: 为判断 中心点是否落在 gt和正方形形成的交集区域内 ,如果在交集内则加 0 不在 则加 10000(本质是为了在不断训练中 让 形成的预测框 形成在交集内部 )

最终可以得到每个gt 与其相应anchor point 的cost 得分,在计算cost时其实也把相应的 anchor point 与 gt的iou损失计算出来了(即pair_wise_iou)(A1,A2…为anchor point)

在这里插入图片描述
之后每个gt 从大到小 取出 不大于10个iou值 进行相加 ,并向下取整,即为每个gt 可以获得 anchor point 数量 ,之后将cost 从小 到大 每个gt取出相应数量的anchor point ,并形成一个混淆矩阵,每个被选到anchor point 下面填 0,没被选上的填 1
在这里插入图片描述在这里插入图片描述如果同一个 anchor point 对应 两个gt时,会选取 cost 较低的 anchor point 另一个会被 填0
在这里插入图片描述

**

至此 每个gt 对应的正样本就选出来了,之后计算总的损失loss

**
原文代码如下:

        # 计算 正样本和对应gt框的 iou损失(回归损失)
        #  源码      if self.loss_type == "iou":
        #               loss = 1 - iou ** 2
        loss_iou    = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
        # 计算 正样本和对应gt框的置信度损失(正样本与1比较)+负样本置信度损失(负样本与 0比较)
        loss_obj    = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
        # !!!!!!! 这边的分类损失与cost 中 不同是用的 正样本与gt的iou 作为cls_targets
        #   cls_targets  :F.one_hot(gt_matched_classes.to(torch.int64),    
        #                            self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)    (即正样本框 与 gt框的 iou)
        #  我猜测是为了让分类正确预测框 更加逼近 真实框
        loss_cls    = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
        reg_weight  = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls
        # 除以正样本数
        loss / num_fg

即下面这张图:
在这里插入图片描述
本文参考
https://www.bilibili.com/video/BV1JW4y1k76c/?spm_id_from=333.880.my_history.page.click&vd_source=9c63f89b714e96dfc638093fbe9f907d
https://zhuanlan.zhihu.com/p/549382358
https://zhuanlan.zhihu.com/p/609370771
https://www.bilibili.com/video/BV1d34y1q7XC/?spm_id_from=333.788&vd_source=9c63f89b714e96dfc638093fbe9f907d

  • 12
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值