【深度学习】【mmdetection】FCOS代码阅读二

本文深入探讨了FCOS目标检测模型的损失函数,分析了get_points的细节,指出正样本的选择标准以及center sampling的重要性。同时,解释了centerness_target如何量化目标中心与预测位置的距离,该值仅对正样本进行计算。
摘要由CSDN通过智能技术生成

【mmdetection】FCOS

损失函数

下面我们再来看看函数里面定义的损失函数。

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        '''
        cls_scores: [5][batchsize,80,H_i,W_i]
        bbox_preds: [5][batchsize,4,H_i,W_i]
        centernesses: [5][batchsize,1,H_i,W_i]
        gt_bboxes: [batchsize][num_obj,4]
        gt_labels: [batchsize][num_obj]
        img_metas: [batchsize][(dict)dict_keys(['filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'img_norm_cfg'])]
        cfg: {'assigner': {'type': 'MaxIoUAssigner', 'pos_iou_thr': 0.5, 'neg_iou_thr': 0.4, 'min_pos_iou': 0, 'ignore_iof_thr': -1}, 'allowed_border': -1, 'pos_weight': -1, 'debug': False}
        '''
        assert len(cls_scores) == len(bbox_preds) == len(centernesses) # 5

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] # P3-P7特征图的大小
        '''
        [torch.Size([100, 152]),
        torch.Size([50, 76]),
        torch.Size([25, 38]),
        torch.Size([13, 19]),
        torch.Size([7, 10])]
        '''
        
        # 特征图的大小就相当于把原图分为多大的grid,特征图每个像素映射到原图就是该grid的中心点,不同大小的特征图就有不同的grid
        # bbox_preds[0].dtype:torch.float32
        # all_level_points:(list) [5][n_points][2]
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
                                           bbox_preds[0].device)
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)
        '''
        labels:[5][batch_size*level_points_i]
        bbox_targets:[5][batch_size*level_points_i,4]
        '''
        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores) # torch.Size([89600, 80]) 所有图片所有point的5个层的输出
        flatten_bbox_preds = torch.cat(flatten_bbox_preds) # torch.Size([89600, 4]) 
        flatten_centerness = torch.cat(flatten_centerness) # torch.Size([89600]) 
        flatten_labels = torch.cat(labels) # torch.Size([89600]) 
        flatten_bbox_targets = torch.cat(b
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值