SimOTA 正负样本分配策略
请带着下面4个要点进行阅读
如下图 gt0,gt1,gt2为真实框,以真实框的中心为中心五倍的格子为边长(五倍是以 原始特征图为参照 )形成一个下图的蓝色的正方形。(代码中是放大到了真实大小)。
以下图为例有20x20=400个anchor point(即grid cell ) 每个anchor point会预测一组tx,ty,tw,th,obj 如上图(在当前特征层中计算出 预测框的 中心点和w,h ,由于为anchor_free故算wh时不用再乘 anchor 模版的w h)。如果anchor point 中心点落在 gt和正方形形成并集区域内 那么这些 anchor point 就有可能成为正样本,是不是真正的正样本还需要进行筛选。
将中心点落在 gt和正方形形成并集的区域内 的anchor point 的预测值进行处理 ,计算出每个预测值与 gt 的cost,计算方式如下
pair_wise_cls_loss: 为anchor point 与gt的分类损失
pair_wise_iou :为anchor point 与gt的回归损失(位置损失)
~is_in_boxes_and_center: 为判断 中心点是否落在 gt和正方形形成的交集区域内 ,如果在交集内则加 0 不在 则加 10000(本质是为了在不断训练中 让 形成的预测框 形成在交集内部 )
最终可以得到每个gt 与其相应anchor point 的cost 得分,在计算cost时其实也把相应的 anchor point 与 gt的iou损失计算出来了(即pair_wise_iou)(A1,A2…为anchor point)
之后每个gt 从大到小 取出 不大于10个iou值 进行相加 ,并向下取整,即为每个gt 可以获得 anchor point 数量 ,之后将cost 从小 到大 每个gt取出相应数量的anchor point ,并形成一个混淆矩阵,每个被选到anchor point 下面填 0,没被选上的填 1
如果同一个 anchor point 对应 两个gt时,会选取 cost 较低的 anchor point 另一个会被 填0
**
至此 每个gt 对应的正样本就选出来了,之后计算总的损失loss
**
原文代码如下:
# 计算 正样本和对应gt框的 iou损失(回归损失)
# 源码 if self.loss_type == "iou":
# loss = 1 - iou ** 2
loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
# 计算 正样本和对应gt框的置信度损失(正样本与1比较)+负样本置信度损失(负样本与 0比较)
loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
# !!!!!!! 这边的分类损失与cost 中 不同是用的 正样本与gt的iou 作为cls_targets
# cls_targets :F.one_hot(gt_matched_classes.to(torch.int64),
# self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1) (即正样本框 与 gt框的 iou)
# 我猜测是为了让分类正确预测框 更加逼近 真实框
loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
reg_weight = 5.0
loss = reg_weight * loss_iou + loss_obj + loss_cls
# 除以正样本数
loss / num_fg
即下面这张图:
本文参考:
https://www.bilibili.com/video/BV1JW4y1k76c/?spm_id_from=333.880.my_history.page.click&vd_source=9c63f89b714e96dfc638093fbe9f907d
https://zhuanlan.zhihu.com/p/549382358
https://zhuanlan.zhihu.com/p/609370771
https://www.bilibili.com/video/BV1d34y1q7XC/?spm_id_from=333.788&vd_source=9c63f89b714e96dfc638093fbe9f907d