Focal Loss for Dense Object Detection

作为2017年ICCV的最佳学生论文,该文章主要提出了一种新的交叉熵损失函数,用于抑制one-stage在训练过程中的易分样本(包括易分正样本与易分负样本)的在loss中的比重,从而挖掘了困难样本对于detector的贡献。提高了detector的性能,为此本文设计了retinanet来验证新的损失函数的作用,通过实验结果表示,取得了不错的效果。

论文链接:https://arxiv.org/abs/1708.02002v2

本文将会详细的介绍:

  1. focal loss的原理
  2. 以mmdection中retinanet的代码为参考,简述整个retinanet从anchor的生成到focal oss的应用,一直到网络的输出的全过程。
  3. 讲述test的全过程。

1.Focal Loss

作者认为类别不均衡是one-stage效果不好的主要原因,two-stage的效果之所以好有以下两个原因:

  1. proposal stage即RPN阶段过滤了大量的背景样本
  2. 在RPN之后的阶段,通过让正样本:负样本=1:3之类的方式,或者困难样本挖掘(OHEM)保持了一种前景与背景之间的平衡

但是对于one-stage的来说,开始的阶段会产生大量的bbox,导致训练过程被易分类的样本所主导。

首先来看二分类的交叉熵CE:

p在[0,1]之间代表样本属于该类的可能性,y等于1或者-1表示ground-truth class.为了简化的写法,定义:

所以CE(p,y)=CE(pt)=-log(pt)

为了平衡正负样本之间的比例(正样本占比少),focal loss引入一个权重因子\alpha

为了让detector去挖掘困难样本的重要性,定义focal loss如下:

这里的\gamma都是取大于0的值,modulating factor为(1-p_{t})^{\gamma }, 可以看出当pt很大,而且确实是正样本的时候,modulating factor就会变得很小,此时易分正样本占的loss就会变得很小,然后如果pt比较小,就会变成困难正样本,此时modulating factor就不会那么小,从而相当于加大了困难正样本占loss的比重。当该样本属于负样本的时候按照公式(2),pt=(1-p),所以此时modulating factor变为p^{\gamma},当p很大的时候表示为难分负样本,所以modulating factor会变得较大,而当p很小的时候,modulating factor就会很小。

\gamma值为0的时候,focal就会变成CE

由图中可以看出modulating factor减少了简单样本的loss权重,扩大了损失函数取很小的值的p的范围变大了。比如当\gamma=2的时候,pt=0.9的损失会比CE小100倍,pt=0.968的时候比CE小1000多倍。这反过来增加了困难样本的重要性。最后在加上\alpha -balance之后:

实验表明加\alpha对于精度会有一些提升。实验中选择\alpha =0.25;\gamma =2\gamma =2。最后放一张结果实验图,从图中可以看出精度确实提高了不少

附一个从求导的角度看focal loss链接:https://blog.csdn.net/leviopku/article/details/89816408

 

2.retinanet

为了验证focal loss的有效性,作者专门设计了一个简单的one-stage的网络结构。下面将会详细的阐述网络的结构以及整个检测train与test的流程。结构与流程都是通过阅读mmdetection的源码所得。

2.1 base_anchor的生成

对于one-stage的detector来说,首先就是在原图中选出初始的若干base_anchor。在mmdetection中选取stride为[8,16,32,64,128],每一个stride按照三种ratio[0.5,1,2],三种scale 4*2^{i}(i从0取到2)这两个参数分别对应mmdetection中config配置文件retinanet_r101_fpn_1x.py的octave_base_scale=4, scales_per_octave=3, anchor_ratios=[0.5, 1.0, 2.0], anchor_strides=[8, 16, 32, 64, 128]。详细过程如下,以anchor_stride=8为例:

    def gen_base_anchors(self):
        w = self.base_size #这里的base_size就是anchor_strides的大小
        h = self.base_size
        if self.ctr is None:
            x_ctr = 0.5 * (w - 1)
            y_ctr = 0.5 * (h - 1)
        else:
            x_ctr, y_ctr = self.ctr

        h_ratios = torch.sqrt(self.ratios) #ration对应三种比例[0.5,1,2],之所以开更号是为了让ws,hs的比例等于ratio
        w_ratios = 1 / h_ratios
        if self.scale_major:
            ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)#这里的scale就是上文的2^{i/3}*4 i=0,1,2
            hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
        else:
            ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
            hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)

        # yapf: disable
        base_anchors = torch.stack(
            [
                x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
                x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
            ],
            dim=-1).round()
        # yapf: enable

        return base_anchors

相当于在8*8的格子里,以中心为base_anchor的中心,按照三种scale,三种ratio产生九个anchor。产生了base_anchor之后该怎么映射到原图中去呢,方法就是相当于原图被划分为了若干个8*8的格子,只需要将刚刚产生的base_anchor通过shift到原图就可以了

    def grid_anchors(self, featmap_size, stride=16, device='cuda'):
        base_anchors = self.base_anchors.to(device)

        feat_h, feat_w = featmap_size #这里的featmap的尺寸就是原图尺寸除以stride之后的,其实就是表示原图被划分为stride*stride大小的格子之后的行数与列数
        shift_x = torch.arange(0, feat_w, device=device) * stride #求出base_anchor在每一行中的偏移
        shift_y = torch.arange(0, feat_h, device=device) * stride#求出base_anchor在每一行中的偏移
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        shifts = shifts.type_as(base_anchors)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)

        all_anchors = base_anchors[None, :, :] + shifts[:, None, :] #base_anchor加上偏移量
        all_anchors = all_anchors.view(-1, 4)
        # first A rows correspond to A anchors of (0, 0) in feature map,
        # then (0, 1), (0, 2), ...
        return all_anchors

形象的解释如下图:

上图就表示原图被划分为若干个8*8的格子,格子中的括号中的数字表示base_anchor(x1,y1,x2,y2)对应需要平移的大小。

每一种stride的大小都可以生成一组9*(H/stride)*(W/stirde)数目的base_anchors,其中9表示三种scale,三种ratio,H,W分别表示原图的高,宽。得到这些anchor之后,我们其实可以计算出这些anchor覆盖到原图的面积范围,最小面积32=8*4*2^0(最小的stride,与最小的scale), 最大的面积为813=128*4*2^(2/3) (最大的stride与最大的scale)

2.2 给base_anchors分配gt

按照上述步骤生成了base_anchors之后,去掉那些超过图像边界的框之后,便可以对每个base_anchor分配gt还有label,分配的过程如下:假设总anchors的数目为m,gt数目为3

首先求与每个anchor IOU最大的gt[1,m]:

  • IOU>pos_iou_thr 将该anchor分配为对应gt的编号
  • IOU<neg_iou_thr 将该anchor分配值为0
  • neg_iou_thr<IOU<pos_iou_thr 这些anchor赋值为-1,在训练过程中不做考虑

接着计算与每个gt IOU最大的anchor[3,1] 将这些anchor分配为对应的gt编号,这样做的目的是为了保证每个gt至少被分配到了一个正样本。

最后将正样本的label赋值为对应gt的对应类label。mmdetection中对应的代码如下:

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """
        This method assign a gt bbox to every bbox (proposal/anchor), each bbox
        will be assigned with -1, 0, or a positive number. -1 means don't care,
        0 means negative sample, positive number is the index (1-based) of
        assigned gt.
        The assignment is done in following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that bbox
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
                shape(k, n).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        if overlaps.numel() == 0:
            raise ValueError('No gt or proposals')

        num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        # for each anchor, which gt best overlaps with it
        # for each anchor, the max iou of all gts
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)
        # for each gt, which anchor best overlaps with it
        # for each gt, the max iou of all proposals
        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)

        # 2. assign negative: below
        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
                             & (max_overlaps < self.neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign fg: for each gt, proposals with highest IoU
        for i in range(num_gts):
            if gt_max_overlaps[i] >= self.min_pos_iou:
                if self.gt_max_assign_all:
                    max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
                    assigned_gt_inds[max_iou_inds] = i + 1
                else:
                    assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            pos_inds = torch.nonzero(assigned_gt_inds > 0).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)

截止到现在,都跟具体的网络结构没有任何关系。其实网络最后输出的预测的bbox的值指的是base_anchors与gt之间的偏移量,

其中gx,gy,gw,gh分别表示gt的中心坐标与宽高,b表示base_anchor的中心坐标与宽高,\delta才是网络需要学习的回归值。最后网络的输出\delta在与base_anchor经过上述公式相反的变换得出最终的预测结果。

2.3 网络结构

首先backbone采用的是reanet101,neck为FPN,在加上retinanet的head。结构图如下:

其中lateral_conv是为了减少特征图的维数,将其统一到256; fpn_conv是为了减少上采样造成的影响。lateral_conv与fpn_conv都是每层都有一个对应的,然后后面的reg_conv,retina_reg则是所有层共享的(cls_conv,retina_cls同理).

最后经过fpn_conv有五个对应的feature map。

以stirde=8为例,这里只看fpn输出的第一个feature map最后经过retinanet_head即上图中的class subnet与box subnet之后分别得到分类预测cls_score[1,27,b,b], bbox偏移预测bbox_pred[1,36,b,b],这里假设是三分类的任务,这里的b其实就是原图大小处于stride之后的结果,27表示每个点有9个anchor,每个anchor属于三种类别的可能性(至于这里为什么没有考虑背景类,只因为后面的focal loss的计算是按照多个二分类交叉熵累加而成),36表示每个点有9个bbox,每个bbox用四个坐标值表示。

2.4 loss的计算

2.4.1 分类损失loss_cls的计算

focal loss就是用于计算loss_cls,达到抑制易分样本,挖掘困难样本的目的。

对于retinanet最后输出的cls_score[1,27,b,b],首先resize为[b*b*9,3] 然后对于每一列分别计算,比如现在计算第一行的第一列,先拿第一行,第一列的第一个值p1,经过p=sigmoid(p1)之后判断是正样本还是负样本,如果是正样本侧让c1=1,否则让c2=1,

                                      loss\_cls=-c1*\alpha *(1-p)^{\gamma }*log(p) -c2*(1-\alpha) *(p)^{\gamma }*log(1-p)

然后依次计算每一行,得到Loss_cls[b*b*9,3], 最后的Loss_cls = Loss_cls*label_weight。其中label_weight[pos_inds] = 1, label_weight[neg_inds]=1,其余既不是正样本也不是负样本的值为0,这里乘以label_weight的目的是为了只计算正负样本(其他既不是正样本也不是负样本的不计算),之后每一行的loss_cls累加起来,得到Loss_cls[b*b*9,1],最后得到的Loss_cls = Loss_cls/num_pos(正样本的数目)

2.4.2 bbox损失reg_loss的计算

对于retinanet最后输出的cls_score[1,36,b,b],首先resize为pred[b*b*9,4], 然后每一行计算与\delta之间的L1距离

                                                              reg_loss = L1(pred,\delta)*box_weight

其中box_weight[pos_inds]=1,其余值为0,就是为了保证只计算正样本对应的reg_loss,mmdetection中用到的L1损失跟普通的优点不一样,具体如下:

                                                                      L1 = \left\{\begin{matrix} 0.5*d^{2} & d < \beta \\ d-0.5\beta & d \geq \beta \end{matrix}\right.

其中d表示两个之间的距离,\beta=0.11,最后的到的reg_loss还要除以num_pos(正样本的数目)(为什么和上面的分类损失一样,这里不用考虑分类损失与回归损失的平衡吗??????)

2.4 test

1.首先对于FPN的每一个level按照输出的retina_cls排序选出前num_pre=1000个retina_reg中的bbox,以第一个level为例,retina_cls[1,27,b,b],retina_reg[1,36,b,b]。首先resize为retina_cls[b*b*9,3],retina_reg为[b*b*9,4]。 然后以每一行的最大值进行排序选出前1000个。接着对应retina_reg中选出该1000个bbox_pred, anchor中选出1000个anchors,接着该1000个bbox与anchor变换为最后的bboxs。对每一个level进行如上操作。得到mlvl_bboxes[m,4], 表示有m个bboxs,mlvl_scores[m,4]表示有m个位置,每个位置对应四个类的概率目标类+背景类)

2.对mlvl_scores[m,4]的每一列,即分别对每一类(背景类不管)选出大于score_thr=0.05的行,用这些行选出对应mlvl_bboxes[m,4]中的bboxs,接着用nms=dict(type='nms', iou_thr=0.5)进行NMS,得到最后的bboxs,为每个bboxs分配目标类label。这样对每一类选出相应的bboxs之后,如果在所有选出的bboxs按照得分选出前max_num个bboxs。

 

3.其他

  1. 在传统的CE损失中,虽然易分样本的loss很低,但是实验中表明大部分的anchor对应的都是简单样本。focal loss其实就是极大的降低了训练过程中易分样本(易分正样本,易分负样本)在loss中的比重,从而使得训练更加关注于那些困难样本。
  2. 实验中用focal loss的时候用的是多个二分类累加,相当于是BCEWithlogitsloss(=sigmoid+focal loss), 而不是多分类CrossEntropyLoss。为什么会这样笔者现在还没弄懂。。。。。。
  3. 在训练的时候有一个很有意思的点就是,在训练的初始化参数的时候All new conv layers except the final one in the RetinaNet subnets are initialized with bias b = 0 and a Gaussian weight fill with σ = 0.01. For the final conv layer of the classification subnet, we set the bias initialization to b = − log((1 − π)/π), 这样做的目的为了让训练开始的时候loss不会被大量的负样本的bbox带偏,进而会发现loss爆炸的情况。具体分析参见:https://zhuanlan.zhihu.com/p/63626711,我觉得写得非常好。
  4. 作者在文中对比了OHEM,得到的实验图如下:

从图中可以看到OHEM 1:3反而比单纯的OHEM的AP要低,这好像说明正负样本不平衡不是最核心的因素,而是由这个因素导出的easy example dominant的问题(参考:https://www.zhihu.com/question/63581984/answer/210832009),不过比较遗憾的是,作者并没有给出OHEM会让结果变差的一个合理解释,这其实也是很值得深挖的一点。

4.

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值