FCOS Loss Computation: A Source Code Walkthrough

This post walks through the loss computation of the FCOS detector: how the box targets required by the loss function are generated from the raw box coordinates, and how boxes of different sizes are assigned to different feature-map levels. Reading the author's source reveals several optimization tricks, such as how the center-sampling region is computed and how targets are routed to FPN levels by preset size ranges. It also covers the loss implementations themselves: SigmoidFocalLoss, IOULoss, and the computation of the classification, box-regression, and centerness losses.

While reading the FCOS paper I kept feeling it was not concrete enough, so I stepped through the author's source code and wrote this walkthrough for my own future reference. Several of the tricks below would be hard to come up with without reading the author's code.

This post covers the following (a minimal toy sketch of both points comes right after the list):

  • How the raw ground-truth box coordinates are turned into the (l, t, r, b) box format the loss function expects
  • How boxes of different sizes are assigned to feature maps at different FPN levels
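
Before reading the full source, here is a minimal, self-contained sketch (my own toy code, not taken from the FCOS repository) of both points: every feature-map location (x, y) regresses the distances (l, t, r, b) to the four sides of a ground-truth box, and the location is kept on FPN level i only when max(l, t, r, b) falls inside that level's size range.

    import torch

    # toy data: 3 locations (x, y) on the input image and a single gt box (xyxy)
    locations = torch.tensor([[100., 100.], [40., 40.], [500., 300.]])
    boxes = torch.tensor([[50., 60., 200., 180.]])

    xs, ys = locations[:, 0], locations[:, 1]
    l = xs[:, None] - boxes[:, 0][None]  # distance to the left side
    t = ys[:, None] - boxes[:, 1][None]  # distance to the top side
    r = boxes[:, 2][None] - xs[:, None]  # distance to the right side
    b = boxes[:, 3][None] - ys[:, None]  # distance to the bottom side
    reg_targets = torch.stack([l, t, r, b], dim=2)  # (num_locations, num_gts, 4)

    # a location lies inside the box iff all four distances are positive
    is_in_box = reg_targets.min(dim=2)[0] > 0
    print(is_in_box.squeeze(1))   # tensor([ True, False, False])

    # level assignment: P3 keeps a location only if max(l, t, r, b) is in [-1, 64]
    max_target = reg_targets.max(dim=2)[0]
    print(max_target.squeeze(1))  # tensor([100., 160., 450.])
    print(((max_target >= -1) & (max_target <= 64)).squeeze(1))  # all False

The location at (100, 100) lies inside the box, but its max target is 100, so the real code would route it to P4, whose range is [64, 128]. With that picture in mind, the full loss source follows.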
    """
    This file contains specific functions for computing losses of FCOS
    file
    """
    
    import torch
    from torch.nn import functional as F
    from torch import nn
    import os
    from ..utils import concat_box_prediction_layers
    from fcos_core.layers import IOULoss
    from fcos_core.layers import SigmoidFocalLoss
    from fcos_core.modeling.matcher import Matcher
    from fcos_core.modeling.utils import cat
    from fcos_core.structures.boxlist_ops import boxlist_iou
    from fcos_core.structures.boxlist_ops import cat_boxlist
    
    
    INF = 100000000
    
    
    def get_num_gpus():
        # WORLD_SIZE is set by the torch.distributed launch utility; default to 1 GPU
        return int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    
    
    def reduce_sum(tensor):
        if get_num_gpus() <= 1:
            return tensor
        import torch.distributed as dist
        tensor = tensor.clone()
        # note: recent PyTorch versions spell this dist.ReduceOp.SUM
        dist.all_reduce(tensor, op=dist.reduce_op.SUM)
        return tensor
    
    
    class FCOSLossComputation(object):
        """
        This class computes the FCOS losses.
        """
    
        def __init__(self, cfg):
            self.cls_loss_func = SigmoidFocalLoss(
                cfg.MODEL.FCOS.LOSS_GAMMA,
                cfg.MODEL.FCOS.LOSS_ALPHA
            )
            self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
            self.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUS
            self.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPE
            self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
    
            # we make use of IOU Loss for bounding boxes regression,
            # but we found that L1 in log scale can yield a similar performance
            self.box_reg_loss_func = IOULoss(self.iou_loss_type)
            self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")
    
        def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):
            '''
            This code is from
            https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
            maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42
            '''
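            # Center sampling in a nutshell: instead of treating every location
            # inside a gt box as positive, shrink each gt to a square of
            # half-size (radius * stride) around its center, clip that square
            # to the gt box, and only count locations inside the clipped square
            # as positives. stride grows with the FPN level, so higher levels
            # get proportionally larger sampling regions.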
            num_gts = gt.shape[0]
            K = len(gt_xs)
            gt = gt[None].expand(K, num_gts, 4)
            center_x = (gt[..., 0] + gt[..., 2]) / 2
            center_y = (gt[..., 1] + gt[..., 3]) / 2
            center_gt = gt.new_zeros(gt.shape)
            # no gt
            if center_x[..., 0].sum() == 0:
                return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)
            beg = 0
            for level, n_p in enumerate(num_points_per):
                end = beg + n_p
                stride = strides[level] * radius
                xmin = center_x[beg:end] - stride
                ymin = center_y[beg:end] - stride
                xmax = center_x[beg:end] + stride
                ymax = center_y[beg:end] + stride
                # limit sample region in gt
                center_gt[beg:end, :, 0] = torch.where(
                    xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0]
                )
                center_gt[beg:end, :, 1] = torch.where(
                    ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1]
                )
                center_gt[beg:end, :, 2] = torch.where(
                    xmax > gt[beg:end, :, 2],
                    gt[beg:end, :, 2], xmax
                )
                center_gt[beg:end, :, 3] = torch.where(
                    ymax > gt[beg:end, :, 3],
                    gt[beg:end, :, 3], ymax
                )
                beg = end
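            # a location is positive for a gt only if it lies strictly inside
            # the clipped center region, i.e. all four distances below are > 0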
            left = gt_xs[:, None] - center_gt[..., 0]
            right = center_gt[..., 2] - gt_xs[:, None]
            top = gt_ys[:, None] - center_gt[..., 1]
            bottom = center_gt[..., 3] - gt_ys[:, None]
            center_bbox = torch.stack((left, top, right, bottom), -1)
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
            return inside_gt_bbox_mask
    
        def prepare_targets(self, points, targets):
            # size range (m_{i-1}, m_i] of the boxes each feature level is
            # responsible for; a location is assigned to the level whose range
            # contains its max regression target
            object_sizes_of_interest = [
                [-1, 64],
                [64, 128],
                [128, 256],
                [256, 512],
                [512, INF],
            ]
            # build 5 tensors, one per feature map:
            # each has one row per location on that level and 2 columns,
            # e.g. the first has size (len(points_per_level), 2) filled with repeated [-1, 64]
            expanded_object_sizes_of_interest = []
            for l, points_per_level in enumerate(points):
                object_sizes_of_interest_per_level = \
                    points_per_level.new_tensor(object_sizes_of_interest[l])
                expanded_object_sizes_of_interest.append(
                    object_sizes_of_interest_per_level[None].expand(
                        len(points_per_level), -1)
                )
            # shapes inside expanded_object_sizes_of_interest:
            # expanded_object_sizes_of_interest[0] is repeated [-1, 64]
            # expanded_object_sizes_of_interest[1] is repeated [64, 128]
            # [e.shape for e in expanded_object_sizes_of_interest]
            # [torch.Size([16128, 2]), torch.Size([4032, 2]), torch.Size([1008, 2]), torch.Size([252, 2]), torch.Size([66, 2])]
    
            # concatenate along the row (location) dimension
            expanded_object_sizes_of_interest = torch.cat(
                expanded_object_sizes_of_interest, dim=0)
            # number of locations on each feature level
            num_points_per_level = [len(points_per_level)
                                    for points_per_level in points]
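            # cached on self so compute_targets_for_locations can pass the
            # per-level point counts to get_sample_region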
            self.num_points_per_level = num_points_per_level
            points_all_level = torch.cat(points, dim=0)
    
            # the box-assignment strategy from the paper:
            # boxes of different sizes go to different feature levels
            # according to the preset size ranges
            # labels: the label of every point on all 5 levels, size: [torch.Size([21486])]
            # reg_targets: the four box coordinates of every point on all 5 levels, size: [torch.Size([21486, 4])]
            labels, reg_targets = self.compute_targets_for_locations(
                points_all_level, targets, expanded_object_sizes_of_interest
            )
    
            # len(labels) equals the batch size
            for i in range(len(labels)):
                labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
                # labels[i] is split into [torch.Size([16128]), torch.Size([4032]), torch.Size([1008]), torch.Size([252]), torch.Size([66])],
                # one chunk per feature level
                # reg_targets is split the same way
                reg_targets[i] = torch.split(
                    reg_targets[i], num_points_per_level, dim=0)
    
            labels_level_first = []
            reg_targets_level_first = []
            # merge the labels and reg_targets of every image at the same level into one tensor
            for level in range(len(points)):
                # concat the points of all images in the batch along the row
                # dimension, then append to labels_level_first
                labels_level_first.append(
                    torch.cat([labels_per_im[level]
                               for labels_per_im in labels], dim=0)
                )
    
                # the same per-batch concat for the regression targets
                reg_targets_per_level = torch.cat([
                    reg_targets_per_im[level]
                    for reg_targets_per_im in reg_targets
                ], dim=0)
    
                if self.norm_reg_targets:
                    # self.fpn_strides:[8, 16, 32, 64, 128]
                    reg_targets_per_level = reg_targets_per_level / \
                        self.fpn_strides[level]
                # then appended to reg_targets_level_first
                reg_targets_level_first.append(reg_targets_per_level)
            # labels_level_first[0]: the labels of all images in the batch at the same level
            # reg_targets_level_first: the box coordinates of all images in the batch at the same level
            return labels_level_first, reg_targets_level_first
    
        def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest):
            labels = []
            reg_targets = []
            xs, ys = locations[:, 0], locations[:, 1]
    
            for im_i in range(len(targets)):
                targets_per_im = targets[im_i]
                assert targets_per_im.mode == "xyxy"
                bboxes = targets_per_im.bbox
                labels_per_im = targets_per_im.get_field("labels")
                area = targets_per_im.area()
    
                # distances from every location to the four sides of every gt
                # box, broadcast to shape (num_locations, num_gts)
                l = xs[:, None] - bboxes[:, 0][None]
                t = ys[:, None] - bboxes[:, 1][None]
                r = bboxes[:, 2][None] - xs[:, None]
                b = bboxes[:, 3][None] - ys[:, None]
                reg_targets_per_im = torch.stack([l, t, r, b], dim=2)
    
                if self.center_sampling_radius > 0:
                    is_in_boxes = self.get_sample_region(
                        bboxes,
                        self.fpn_strides,
                        self.num_points_per_level,
                        xs, ys,
                        radius=self.center_sampling_radius
                    )
                else:
                    # no center sampling, it will use all the locations within a ground-truth box
                    is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0
    
                max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
                # limit the regression range for each location
                is_cared_in_the_level = \
                    (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \
                    (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]])
    
                locations_to_gt_area = area[None].repeat(len(locations), 1)
                locations_to_gt_area[is_in_boxes == 0] = INF
                locations_to_gt_area[is_cared_in_the_level == 0] = INF
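                # locations failing either test get area INF so the argmin
                # below can never pick them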
    
                # if there are still more than one objects for a location,
                # we choose the one with minimal area
                locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(
                    dim=1)
    
                reg_targets_per_im = reg_targets_per_im[range(
                    len(locations)), locations_to_gt_inds]
                labels_per_im = labels_per_im[locations_to_gt_inds]
                # locations whose minimum area is still INF matched no gt: background (label 0)
                labels_per_im[locations_to_min_area == INF] = 0
    
                labels.append(labels_per_im)
                reg_targets.append(reg_targets_per_im)
    
            return labels, reg_targets
    
        def compute_centerness_targets(self, reg_targets):
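            # centerness = sqrt( min(l, r) / max(l, r) * min(t, b) / max(t, b) )
            # equals 1 at the exact box center and decays to 0 towards any border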
            left_right = reg_targets[:, [0, 2]]
            top_bottom = reg_targets[:, [1, 3]]
            centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
            return torch.sqrt(centerness)
    
        def __call__(self, locations, box_cls, box_regression, centerness, targets):
            """
            Arguments:
                locations (list[BoxList])
                box_cls (list[Tensor])
                box_regression (list[Tensor])
                centerness (list[Tensor])
                targets (list[BoxList])
    
            Returns:
                cls_loss (Tensor)
                reg_loss (Tensor)
                centerness_loss (Tensor)
            """
            N = box_cls[0].size(0)
            num_classes = box_cls[0].size(1)
    
            # build the classification and box-coordinate ground truth from locations and targets
            # shapes of the elements of locations:
            # [torch.Size([15200, 2]), torch.Size([3800, 2]), torch.Size([950, 2]), torch.Size([247, 2]), torch.Size([70, 2])]
            # torch.Size([15200, 2]): level 0 has 15200 points, holding each point's coordinates on the input image
            # targets
            # [BoxList(num_boxes=3, image_width=1201, image_height=800, mode=xyxy)]
            # labels
            # [torch.Size([15200]), torch.Size([3800]), torch.Size([950]), torch.Size([247]), torch.Size([70])]
            # reg_targets
            # [torch.Size([15200, 4]), torch.Size([3800, 4]), torch.Size([950, 4]), torch.Size([247, 4]), torch.Size([70, 4])]
            labels, reg_targets = self.prepare_targets(locations, targets)
    
            box_cls_flatten = []
            box_regression_flatten = []
            centerness_flatten = []
            labels_flatten = []
            reg_targets_flatten = []
    
            # reshape the predicted box/cls and the GT box/cls of every feature level
            for l in range(len(labels)):
                box_cls_flatten.append(box_cls[l].permute(
                    0, 2, 3, 1).reshape(-1, num_classes))
                box_regression_flatten.append(
                    box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
                labels_flatten.append(labels[l].reshape(-1))
                reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
                centerness_flatten.append(centerness[l].reshape(-1))
    
            # concatenate along the row dimension
            box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
            box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
            centerness_flatten = torch.cat(centerness_flatten, dim=0)
            labels_flatten = torch.cat(labels_flatten, dim=0)
            reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)
            # resulting shapes:
            # box_cls_flatten: torch.Size([20267, 80])
            # box_regression_flatten: torch.Size([20267, 4])
            # labels_flatten: torch.Size([20267])
            # reg_targets_flatten: torch.Size([20267, 4])
            # centerness_flatten: torch.Size([20267])
    
            # pick out the feature-map locations that carry a class label (the positives)
            pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
            # pos_inds : tensor([ 7971,  7972,  7973,  8123,  8124,  8125,  8275,  8276,  8277, 17133,
            # 17134, 17135, 20057, 20058, 20059, 20068, 20069, 20070, 20076, 20077,
            # 20078, 20087, 20088, 20089, 20095, 20096, 20097, 20106, 20107, 20108,
            # 20243, 20244], device='cuda:0')
    
            # gather the positive samples indexed by pos_inds
            box_regression_flatten = box_regression_flatten[pos_inds]
            reg_targets_flatten = reg_targets_flatten[pos_inds]
            centerness_flatten = centerness_flatten[pos_inds]
    
            num_gpus = get_num_gpus()
            # sync num_pos from all gpus
            total_num_pos = reduce_sum(
                pos_inds.new_tensor([pos_inds.numel()])).item()
            num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)
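            # (the max(..., 1.0) above keeps the divisions below safe on a GPU
            # that sees no positive samples in this batch)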
    
            # classification loss: SigmoidFocalLoss
            cls_loss = self.cls_loss_func(
                box_cls_flatten,
                labels_flatten.int()
            ) / num_pos_avg_per_gpu
    
            if pos_inds.numel() > 0:
                # centerness targets computed from the gt regression targets
                centerness_targets = self.compute_centerness_targets(
                    reg_targets_flatten)
    
                # average sum_centerness_targets from all gpus,
                # which is used to normalize centerness-weighed reg loss
                sum_centerness_targets_avg_per_gpu = \
                    reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
    
                # loss on the box coordinates
                reg_loss = self.box_reg_loss_func(
                    box_regression_flatten,
                    reg_targets_flatten,
                    centerness_targets
                ) / sum_centerness_targets_avg_per_gpu
                # centerness loss
                centerness_loss = self.centerness_loss_func(
                    centerness_flatten,
                    centerness_targets
                ) / num_pos_avg_per_gpu
            else:
                # the image contains no gt boxes
                reg_loss = box_regression_flatten.sum()
                # still join the all_reduce so the other GPUs are not left waiting
                reduce_sum(centerness_flatten.new_tensor([0.0]))
                centerness_loss = centerness_flatten.sum()
    
            return cls_loss, reg_loss, centerness_loss
    
    
    def make_fcos_loss_evaluator(cfg):
        loss_evaluator = FCOSLossComputation(cfg)
        return loss_evaluator
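
As a quick sanity check of the centerness branch, here is a small standalone snippet (my own toy code, independent of fcos_core) that reproduces the compute_centerness_targets logic on two hand-made (l, t, r, b) targets and feeds them to the same BCEWithLogitsLoss(reduction="sum") used above:

    import torch
    from torch import nn

    def compute_centerness_targets(reg_targets):
        # identical math to the method above, copied here so this runs standalone
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness)

    # one location exactly at the box center, one close to a corner
    reg_targets = torch.tensor([[50., 50., 50., 50.],
                                [10., 10., 90., 90.]])
    targets = compute_centerness_targets(reg_targets)
    print(targets)  # tensor([1.0000, 0.1111])

    # raw (pre-sigmoid) centerness predictions for the two locations
    centerness_pred = torch.tensor([2.0, -2.0])
    loss = nn.BCEWithLogitsLoss(reduction="sum")(centerness_pred, targets)
    print(loss)

The off-center location gets centerness 1/9, which is how FCOS down-weights low-quality box predictions made far from the object center.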
    
    

Reposted from: https://blog.csdn.net/sunlanchang/article/details/103962394
