BorderDet: Border Feature for Dense Object Detection论文解读-代码pytorch

BorderDet: Border Feature for Dense Object Detection

兄弟们,:-O来啦~ 美好的一天哇

论文地址1: https://arxiv.org/pdf/2007.11056.pdf
代码地址2: https://github.com/MegviiBaseDetection/BorderDet
作者分析3: https://zhuanlan.zhihu.com/p/163044323

其实作者已经做了一些分析,但是自己有的地方还是不懂,所以就着代码读了读,感觉一些细节补充一下,欢迎纠正~

概述

该文提出了一种非常简单、高效的操作来提取物体边界极限点的特征,叫做“BorderAlign”。模型只增加很少的时间开销,可以在经典模型上FCOS(38.6 v.s. 41.4). FPN(37.1 v.s. 40.7)。

背景

对于最终检测框的描述有很多种,本文提出了一种方式,可以比简单的单一的中心点要好.

如下图,这个运动员中心的五角星位置即为anchor点,但是确定该物体边界框的主要是边界上的四个橘色圆点,这个运动员的边界框的位置主要由四个极限点来确定。用其他的方法可能会引入一些有害的信息,且不能直接有效的提取到真正有用的边界极限点。
在这里插入图片描述

动机

基于上述两条“痛点”,我们觉得提取物体的边界极限点的特征是不是能对物体精准定位有一些帮助?于是乎我们做了最基础的一些实验,我们基于FCOS的检测器,增加一些enhancement的特征来加强单点特征。主要有4组实验,分别来验证上面提到的两条问题。(1) single point: 单点特征做增强; (2) region: 用ROIAlign提取框内所有特征来增强; (3)border使用边界上所有点的特征来增强 (4)只用边界中心点来增强。这四个实验的特征采样位置如下图,分别对应不同的采样点个数。[引用于3]

我们发现,提取边界的中心点的特征,能够达到和region feature同样的结果,且采样点个数少了很多。这意味着只需要更少的复杂度,能够高效的提取到有用的特征。且同时证明了,边界极限点特征对物体定位确实有非常重要的作用。[引用于3]

本工作一个比较大的亮点是,打破了人们以往的认知,直击对ranking和框的精修比较重要的因素——边界极限点特征,并提供了一套简单、通用的操作来提取极限点的特征,为目标检测领域中检测框特征表达提供了一个全新的思路![引用于3]

在这里插入图片描述

方法

在这里插入图片描述
我们这里仔细研读一下这个结构;

BorderAlign

这是本工作最核心的一个操作,用来显式、自适应的提取物体边界的特征。如上图右上角。对于一个特征图,通道个数为5xC**(C是channel数,代码里对于类别C为256,对于回归框C为128)**,这是一个border-sensitive的特征图,分别对应物体4个边界特征和原始anchor点位置的特征。对于一个anchor点预测的一个框,我们把这个框的4个border对应在特征图上的特征分别做pooling操作。且由于框的位置是小数,所以该操作使用双线性插值取出每个border上的特征。如图所示,我们每条边会先选出5个待采样点,再对这5个待采样点取最大的值,作为该条边的特征,即每条边最后只会选出一个采样点作为输出。那么每个anchor点都会采样5个点的特征作为输出,即输出的通道数也为5xC个。

        self.add_module("border_cls_subnet", BorderBranch(in_channels, 256))
        self.add_module("border_bbox_subnet", BorderBranch(in_channels, 128))

对于每个样本的groudtruth来说,它的输入不仅仅只有:图像,bbox,还有前面feature_map预测出来的粗糙的框(coarse bbox),如下面代码所示,shifts是图片,targets就是instance,pre_boes_list:就是之前的feataure_map预测出来的粗糙框,那么粗糙的框是干嘛用的呢?

    @torch.no_grad()
    def get_ground_truth(self, shifts, targets, pre_boxes_list):
        """
        Args:
            shifts (list[list[Tensor]]): a list of N=#image elements. Each is a
                list of #feature level tensors. The tensors contains shifts of
                this image on the specific feature level.
            targets (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.  Specify `targets` during training only.

        Returns:
            gt_classes (Tensor):
                An integer tensor of shape (N, R) storing ground-truth
                labels for each shift.
                R is the total number of shifts, i.e. the sum of Hi x Wi for all levels.
                Shifts in the valid boxes are assigned their corresponding label in the
                [0, K-1] range. Shifts in the background are assigned the label "K".
                Shifts in the ignore areas are assigned a label "-1", i.e. ignore.
            gt_shifts_deltas (Tensor):
                Shape (N, R, 4).
                The last dimension represents ground-truth shift2box transform
                targets (dl, dt, dr, db) that map each shift to its matched ground-truth box.
                The values in the tensor are meaningful only when the corresponding
                shift is labeled as foreground.
            gt_centerness (Tensor):
                An float tensor (0, 1) of shape (N, R) whose values in [0, 1]
                storing ground-truth centerness for each shift.
            border_classes (Tensor):
                An integer tensor of shape (N, R) storing ground-truth
                labels for each shift.
                R is the total number of shifts, i.e. the sum of Hi x Wi for all levels.
                Shifts in the valid boxes are assigned their corresponding label in the
                [0, K-1] range. Shifts in the background are assigned the label "K".
                Shifts in the ignore areas are assigned a label "-1", i.e. ignore.
            border_shifts_deltas (Tensor):
                Shape (N, R, 4).
                The last dimension represents ground-truth shift2box transform
                targets (dl, dt, dr, db) that map each shift to its matched ground-truth box.
                The values in the tensor are meaningful only when the corresponding
                shift is labeled as foreground.

        """
        # 在获得gt的时候,除了img,box,还需要pre_boxes_list
        print("shifts ", len(shifts), len(shifts[0]),len(shifts[0][0]),len(shifts[0][0][0]))
        gt_classes = []
        gt_shifts_deltas = []
        gt_centerness = []

        border_classes = []
        border_shifts_deltas = []

        for shifts_per_image, targets_per_image, pre_boxes in zip(shifts, targets, pre_boxes_list):
            object_sizes_of_interest = torch.cat([
                shifts_i.new_tensor(size).unsqueeze(0).expand(
                    shifts_i.size(0), -1) for shifts_i, size in zip(
                    shifts_per_image, self.object_sizes_of_interest)
            ], dim=0)

            shifts_over_all_feature_maps = torch.cat(shifts_per_image, dim=0)
            # [3756, 2] [4092, 2])
            print("shifts_over_all_feature_maps ", shifts_over_all_feature_maps.size())
            gt_boxes = targets_per_image.gt_boxes

            deltas = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))
            print("deltas ", deltas.size())
            if self.center_sampling_radius > 0:
                centers = gt_boxes.get_centers()
                is_in_boxes = []
                for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
                    radius = stride * self.center_sampling_radius
                    center_boxes = torch.cat((
                        torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                        torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
                    ), dim=-1)
                    center_deltas = self.shift2box_transform.get_deltas(
                        shifts_i, center_boxes.unsqueeze(1))
                    is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
                is_in_boxes = torch.cat(is_in_boxes, dim=1)
            else:
                # no center sampling, it will use all the locations within a ground-truth box
                is_in_boxes = deltas.min(dim=-1).values > 0

            max_deltas = deltas.max(dim=-1).values
            # limit the regression range for each location
            is_cared_in_the_level = \
                (max_deltas >= object_sizes_of_interest[None, :, 0]) & \
                (max_deltas <= object_sizes_of_interest[None, :, 1])

            gt_positions_area = gt_boxes.area().unsqueeze(1).repeat(
                1, shifts_over_all_feature_maps.size(0))
            gt_positions_area[~is_in_boxes] = math.inf
            gt_positions_area[~is_cared_in_the_level] = math.inf

            # if there are still more than one objects for a position,
            # we choose the one with minimal area
            positions_min_area, gt_matched_idxs = gt_positions_area.min(dim=0)

            # ground truth box regression
            gt_shifts_reg_deltas_i = self.shift2box_transform.get_deltas(
                shifts_over_all_feature_maps, gt_boxes[gt_matched_idxs].tensor)

            # ground truth classes
            has_gt = len(targets_per_image) > 0
            if has_gt:
                gt_classes_i = targets_per_image.gt_classes[gt_matched_idxs]
                # Shifts with area inf are treated as background.
                gt_classes_i[positions_min_area == math.inf] = self.num_classes
            else:
                gt_classes_i = torch.zeros_like(gt_matched_idxs) + self.num_classes

            # ground truth centerness
            left_right = gt_shifts_reg_deltas_i[:, [0, 2]]
            top_bottom = gt_shifts_reg_deltas_i[:, [1, 3]]
            gt_centerness_i = torch.sqrt(
                (left_right.min(dim=-1).values / left_right.max(dim=-1).values).clamp_(min=0)
                * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values).clamp_(min=0)
            )

            gt_classes.append(gt_classes_i)
            gt_shifts_deltas.append(gt_shifts_reg_deltas_i)
            gt_centerness.append(gt_centerness_i)

            # border
            iou = pairwise_iou(Boxes(pre_boxes), gt_boxes)
            (max_iou, argmax_iou) = iou.max(dim=1)
            invalid = max_iou < self.border_iou_thresh
            gt_target = gt_boxes[argmax_iou].tensor

            border_cls_target = targets_per_image.gt_classes[argmax_iou]
            border_cls_target[invalid] = self.num_classes

            border_bbox_std = pre_boxes.new_tensor(self.border_bbox_std)
            pre_boxes_wh = pre_boxes[:, 2:4] - pre_boxes[:, 0:2]
            pre_boxes_wh = torch.cat([pre_boxes_wh, pre_boxes_wh], dim=1)
            print("gt_target ", list(gt_target.size()) )  # [4092, 4] [3756, 4] [4440, 4]

            border_off_target = (gt_target - pre_boxes) / (pre_boxes_wh * border_bbox_std)

            border_classes.append(border_cls_target)
            border_shifts_deltas.append(border_off_target)

        return (
            torch.stack(gt_classes),
            torch.stack(gt_shifts_deltas),
            torch.stack(gt_centerness),
            torch.stack(border_classes),
            torch.stack(border_shifts_deltas),
        )

border_off_target = (gt_target - pre_boxes) / (pre_boxes_wh * border_bbox_std)

这句代码,就是根据论文中公式(2)来求出的off值, ( x 0 , y 0 , x 1 , y 1 ) (x_0,y_0,x_1,y_1) (x0,y0,x1,y1)是预测出来的粗糙的框,并且IOU>0.6的正样本,这样的粗糙的框会和 ( x 0 t , y 0 t , x 1 t , y 1 t ) (x^t_0,y^t_0,x^t_1,y^t_1) (x0t,y0t,x1t,y1t) groundtruth真实值会计算公式二,获得偏差; 而这个偏差,就是BorderBranch中border_bbox_pred分支所要学习的东西;

在这里插入图片描述

最终losses返回的函数,一共有5个组成部分,其中 “loss_cls” “loss_box_reg” “loss_centerness"就是基于粗糙框基于FCOS的三个loss, “loss_border_cls”, “loss_border_reg” 就是border分支下新的类别loss和框的回归loss. 这里, 跟人觉得"loss_cls”“loss_border_cls” 所对应的groundtruth是一样的.都是grouthtruth列别. 但是两个分支所学习的东西有可能有差别.

最终的loss函数如下: 前一部分是coarse的loss, 后一部分就是border_branch分支的loss,这里N_pos就是IOU>0.6的positive的框才会进行运算. 这里的△*就是border_off_target;
在这里插入图片描述

        return {
            "loss_cls": loss_cls,
            "loss_box_reg": loss_box_reg,
            "loss_centerness": loss_centerness,
            "loss_border_cls": loss_border_cls,
            "loss_border_reg": loss_border_reg,
        }

BAM(Border Alignment Module)

对于来自于FPN的特征,borderAlign需要5xC的border-sensitive的特征图,所以需要将通道先升维,提取完后再降维。为了“最干净”的验证Border feature的有效性,BAM中使用1x1conv来做升降维,几乎不增加模型的计算量。最后还原到256通道,来做最终的边界的分类和回归。

Pooling Size.

Border Align中Pooling Size是一个超参,论文中使用10; 每个Border分支都是被分割成N个点, 并且使用MAX_POOLING来计算特征值,提出的边界对齐可以自适应地利用边界极值点的代表性边界特征。(这个有点向双阶段的ROI ALIGH)

BorderAlign首先将每个边框细分为若干点,然后在每个边框上进行池化以提取边框特征。在边界对齐过程中,引入了一个新的超参数–池大小。比较了BorderAlign中不同池大小的检测性能。结果见表。3.当池大小等于0时,实验相当于迭代预测包围盒。实验结果表明,该算法在较大范围内对Pooling Size的取值具有较强的鲁棒性。由于较大的Pooling Size需要额外的计算.

其实pooling_size 在粗糙的bouding框中就分割成了N部分, 最终会确认每个border在哪个N中最大,取featura_map的极值点来确认处于哪个分割部分;再一次精细了框.

模型训练及预测

BorderDet的秉持一贯“朴素、简单且有效”,训练和前传都非常简单,没有hack什么trick。

  1. 分类损失函数就是focal_loss,超参和FCOS完全一致。
  2. 回归就是最简单的smoothl1,精修FCOS预测框的位置。
  3. 最后的总的损失函数就是把所有的加在一起,损失函数的权重都是1。

作者,也是评价了最近的网络进行对比(非论文内容):

在这里插入图片描述

论文中
在这里插入图片描述

如果小朋友门用单个GPU跑模型的话,请注意将word_size > 1 改成 word_size >= 1; 具体的有可能忘了,看看github 的issue, 有小宝贝有同样的问题.

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值