FreeAnchor 原理与代码解析

论文 FreeAnchor: Learning to Match Anchors for Visual Object Detection

官方代码 https://github.com/zhangxiaosong18/FreeAnchor

        作者指出IoU-based label assignment对于acentric, slinder, crowded objects,其正负样本的分配可能效果不好。比如以下图的月亮为例,其中绿框是gt box,红框是anchor,在基于IoU的规则下可能将该anchor分配给该gt作为正样本,但从图中可以看到anchor box中只包含了月亮的一小部分,大部分都是黑色背景。这显然是不合理的。

因此作者提出了一种learning-to-match的方法,即在训练过程中让模型自己去学习挑选合适的anchor作为正负样本,抛弃基于IoU的这种hand-crafted方法。这种方法从三个方面来优化目标检测模型的学习过程。

  1. 高召回率
        对于每个gt,保证至少有一个anchor去匹配该gt,即至少有一个anchor作为正样本,其proposal负责预测该gt。

  2. 高精度
        对于定位差的anchor,模型要将其分类为背景。

  3. 兼容NMS
        即分类优先,分类score越高,其定位也应该越准。否则定位很好但分类score低的结果就被去除了。

为了同时满足这三个方面,作者将object-anchor matching设计成极大似然估计问题,在训练过程中使用模型的预测结果去寻找合适的anchor,反过来再推断模型的参数。        

对于一张输入图片\(I\),gt annotations定义为\(B\),其中一个gt box \(b_{i}\in B\)由类别标签\(b_{i}^{cls}\)和位置标签\(b_{i}^{loc}\)组成。在前向传播过程中,每个anchor \(a_{j}\in A\)在sigmoid函数后得到一个类别预测结果\(a_{j}^{cls}\in \mathbb{R}^{k}\),\(k\)是类别个数,在边框回归后得到一个位置的预测结果\(a_{j}^{loc}=\left \{ x,y,w,h \right \}\)。

训练过程中,基于IoU的分配规则会生成一个矩阵\(C_{ij}\in \left \{ 0,1 \right \}\)来定义gt \(b_{i}\)是否与anchor \(a_{j}\)匹配,当\(b_{i}\)和\(a_{j}\)的IoU大于设定阈值时,\(b_{i}\)和\(a_{j}\)匹配且\(C_{ij}=1\),否则\(C_{ij}=0\)。当一个anchor和多个gt的IoU都大于设定阈值时,取IoU最大的gt与该anchor匹配,即要保证每个anchor至多和一个gt匹配,即 \(\sum_{i}C_{ij}\in \left \{ 0,1 \right \},\forall a_{j}\in A\)。

定义\(A_{+}\subseteq A\)为\(\left \{ a_{j}|\sum _{i}C_{ij}=1 \right \}\),\(A_{-}\subseteq A\)为\(\left \{ a_{j}|\sum _{i}C_{ij}=0 \right \}\),目标检测模型的Loss通常定义成如下形式

其中\(\theta\)是待学习的模型参数,\(L_{ij}^{cls}(\theta )=BCE(a_{j}^{cls},b_{i}^{cls},\theta)\),\(L_{ij}^{loc}(\theta )=SmoothL1(a_{j}^{loc},b_{i}^{loc},\theta)\),\(L_{j}^{bg}(\theta )=BCE(a_{j}^{cls},\vec{0},\theta)\),\(\beta\)是正则化因子,\(bg\)指背景。

将上式Loss转换成似然概率,如下

其中\(P_{ij}^{cls}(\theta)\)和\(P_{j}^{bg}(\theta)\)表示分类置信度,\(P_{ij}^{loc}(\theta)\)代表定位置信度。减小(1)式中的loss等价于增大(2)式中的似然概率。

注意,式(2)中第一项本应如\(\prod _{a_{j}\in A_{+}}(e^{-\sum _{b_{i}\in B}C_{ij}L_{ij}^{cls}(\theta)}))\)所示,因为\(\sum_{b_{i}\in B}C_{ij}\)在\(j\)确定的情况下,只有一项等于1,其它都为0,因此可以将其移到\(e\)前面。

式(2)从MLE的角度同时考虑到了分类和定位的优化,但忽视了如何学习匹配矩阵\(C_{ij}\)。

Detection Customized Likelihood

因此作者提出了针对检测的似然概率Detection Customized Likelihood。首先,对于每个目标\(b_{i}\),挑选\(n\)个IoU最大的anchor \(A_{i}\subset A\)作为候选正样本,然后在maximizing detection customized likelihood的同时去学习匹配最优的anchor。

为了优化召回率,对于每个目标\(b_{i} \subset B\),需要保证至少有一个anchor \(a_{j}\subset A_{i}\),它的预测\(a_{j}^{cls}\)和\(a_{j}^{loc}\)和ground truth较为接近,然后去负责预测该gt。优化召回率对应的似然函数如下所示 

这个公式的含义是,对于每个目标\(i\),取anchor候选集\(A_{i}\)中分类和定位置信度乘积最大的那个,然后去优化这个anchor,即最大化这个乘积。

为了优化精度,需要把定位差的anchor分类为背景。对应的似然函数如下

其中\(P\left \{ a_{j}\in A_{-} \right \}=1-max_{i}P\left \{ a_{j}\rightarrow b_{i} \right \}\)是\(a_{j}\)没有和任一个gt匹配上的概率,\(P\left \{ a_{j}\rightarrow b_{i} \right \}\)是\(a_{j}\)正确预测\(b_{i}\)的概率。

这个公式的含义是,从定位的角度来看,当一个anchor属于背景,那么\(P\left \{ a_{j}\rightarrow b_{i} \right \}\)就比较大。从分类的角度我们希望将其分类到背景,即\(P_{j}^{bg}(\theta)\)大,这样\(1-P_{j}^{bg}(\theta)\)才能小,\(P\left \{ a_{j}\subset A_{-} \right \}(1-P_{j}^{bg}(\theta))\)才能小,精度才能变大。

为了和NMS兼容,\(P\left \{ a_{j}\rightarrow b_{i} \right \}\)应该满足下面三个性质

  1. \(P\left \{ a_{j}\rightarrow b_{i} \right \}\)应该是\(a_{j}\)和\(b_{i}\)的IoU即\(IoU_{ij}^{loc}\)的单调递增函数
  2. 当\(IoU_{ij}^{loc}\)小于阈值\(t\)时,\(P\left \{ a_{j}\rightarrow b_{i} \right \}\)应该趋近于0
  3. 对于一个gt \(b_{i}\),应该存在且仅存在一个\(a_{j}\)满足\(P\left \{ a_{j}\rightarrow b_{i} \right \}=1\)

饱和线性函数如下所示,可以满足上述3个性质

因此定义\(P\left \{ a_{j}\rightarrow b_{i} \right \}=Saturated\;linear(IoU_{ij}^{loc},t,max_{j}(IoU_{ij}^{loc}))\)

这样就满足上述优化召回率、优化精度、和NMS兼容三个条件,并且可以达到我们希望在训练过程中free object-anchor matching的要求。

Anchor Matching Mechanism

为了训练,再将上面的似然概率函数转换成Loss

其中max函数用来为每个目标\(b_{i}\)挑选最优的一个anchor,在训练过程中,从anchor候选正样本\(A_{i}\)中挑选出一个最优的用来更新模型权重\(\theta\)。

但在训练初期,所有anchor的置信度都比较小,置信度最高的anchor不一定是最匹配的。因此作者提出使用Mean-Max函数,定义如下

当训练不充分时,如下图所示,Mean-max类似于Mean函数,这意味着候选\(A_{i}\)中几乎所有anchor都参与训练了。随着训练的进行,某些anchor的置信度逐渐增加,Mean-max类似于Max函数,当训练充分时,最合适的一个anchor会从\(A_{i}\)中挑选出来匹配目标\(b_{i}\)。

把式(6)中的max函数替换为mean-max函数,添加权重因子\(w_{1},w_{2}\),式中的第二项应用focal loss,FreeAnchor最终的loss如下所示

其中\(X_{i}=\left \{ P_{ij}^{cls}(\theta)P_{ij}^{loc}(\theta)|a_{j}\subset A_{i} \right \}\)是对应anchor候选集\(A_{i}\)的似然概率集和。使用focal loss中的参数\(\alpha,\gamma\),设置\(w_{1}=\frac{\alpha}{\begin{Vmatrix}
B
\end{Vmatrix}}\),\(w_{2}=\frac{1-\alpha}{n\begin{Vmatrix}
B
\end{Vmatrix}}\),\(FL(x)=-x^{\gamma}log(1-x)\)。

定义好detection customized loss后,训练过程如下所示

代码

下面是mmdet中的实现,对中间一些输出加了一些注释

class FreeAnchorRetinaHead(RetinaHead):
    """FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int): Number of conv layers in cls and reg tower.
            Default: 4.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        pre_anchor_topk (int): Number of boxes that be token in each bag.
        bbox_thr (float): The threshold of the saturated linear function. It is
            usually the same with the IoU threshold used in NMS.
        gamma (float): Gamma parameter in focal loss.
        alpha (float): Alpha parameter in focal loss.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 pre_anchor_topk=50,
                 bbox_thr=0.6,
                 gamma=2.0,
                 alpha=0.5,
                 **kwargs):
        super(FreeAnchorRetinaHead,
              self).__init__(num_classes, in_channels, stacked_convs, conv_cfg,
                             norm_cfg, **kwargs)

        self.pre_anchor_topk = pre_anchor_topk
        self.bbox_thr = bbox_thr
        self.gamma = gamma
        self.alpha = alpha  # 0.5

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        # cls_scores: [(1,180,38,38),(1,180,19,19),(1,180,10,10),(1,180,5,5),(1,180,3,3)]
        # bbox_preds: [(1,36,38,38),(1,36,19,19),(1,36,10,10),(1,36,5,5),(1,36,3,3)]
        # gt_bboxes: [tensor([[0.0000,   0.0000, 300.0000, 300.0000],
        #                     [0.0000,   0.0000, 207.5188, 200.0000]], device='cuda:0')]
        # gt_labels: [tensor([12, 14], device='cuda:0')]
        # img_metas:
        # gt_bboxes_ignore: None
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]  # [(38,38),(19,19),(10,10),(5,5),(3,3)]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
        # len(anchor_list)=batch_size=1, len(anchor_list[0])=out_levels_num=5
        # torch.Size([12996, 4])  38*38*9
        # torch.Size([3249, 4])
        # torch.Size([900, 4])
        # torch.Size([225, 4])
        # torch.Size([81, 4])
        anchors = [torch.cat(anchor) for anchor in anchor_list]  # [(17451,4)]

        # concatenate each level
        cls_scores = [
            cls.permute(0, 2, 3,
                        1).reshape(cls.size(0), -1, self.cls_out_channels)
            for cls in cls_scores
        ]
        # torch.Size([1, 12996, 20])
        # torch.Size([1, 3249, 20])
        # torch.Size([1, 900, 20])
        # torch.Size([1, 225, 20])
        # torch.Size([1, 81, 20])
        bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
            for bbox_pred in bbox_preds
        ]
        # torch.Size([1, 12996, 4])
        # torch.Size([1, 3249, 4])
        # torch.Size([1, 900, 4])
        # torch.Size([1, 225, 4])
        # torch.Size([1, 81, 4])
        cls_scores = torch.cat(cls_scores, dim=1)  # (1, 17451, 20)
        bbox_preds = torch.cat(bbox_preds, dim=1)  # (1, 17451, 4)

        cls_prob = torch.sigmoid(cls_scores)
        box_prob = []
        num_pos = 0
        positive_losses = []
        for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_,
                bbox_preds_) in enumerate(
                    zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)):

            with torch.no_grad():  # 注意这里要取消梯度
                if len(gt_bboxes_) == 0:
                    image_box_prob = torch.zeros(
                        anchors_.size(0),
                        self.cls_out_channels).type_as(bbox_preds_)
                else:
                    # box_localization: a_{j}^{loc}, shape: [j, 4]
                    pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_)  # (17451,4),(17451,4) -> (17451,4)

                    # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                    object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)  # (2,4),(17451,4) -> (2,17451)

                    # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                    t1 = self.bbox_thr  # 0.6
                    t2 = object_box_iou.max(
                        dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)  # (2,1), tensor([[0.7288],[0.6268]])
                    object_box_prob = ((object_box_iou - t1) /
                                       (t2 - t1)).clamp(
                                           min=0, max=1)  # (2,17451)

                    # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                    num_obj = gt_labels_.size(0)  # 2
                    indices = torch.stack([
                        torch.arange(num_obj).type_as(gt_labels_), gt_labels_  # tensor([12,14],device='cuda:0')
                    ],
                                          dim=0)  # (2,2), tensor([[0,1],[12,14]])

                    object_cls_box_prob = torch.sparse_coo_tensor(
                        indices, object_box_prob)  # (2,15,17451)
                    # cj
                    # tmp = object_cls_box_prob.to_dense()
                    # import numpy as np
                    # for i in range(15):
                    #     print(np.sum(tmp.cpu().numpy()[:, i, :]))
                    # exit()

                    # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
                    """
                    from "start" to "end" implement:
                    image_box_iou = torch.sparse.max(object_cls_box_prob,
                                                     dim=0).t()

                    """
                    # start
                    box_cls_prob = torch.sparse.sum(
                        object_cls_box_prob, dim=0).to_dense()  # (15,17451)
                    # 若两个gt属于不同类别,相加时总有一个值为0。
                    # 若两个gt属于同一类别,但object_box_prob > 0的anchor完全错开,相加时也总有一个值为0。
                    # 只有当两个gt属于同一类别,并且同一个anchor与两个gt的object_box_prob都大于0时,相加时和才会发生变化。
                    # 但没有关系,因为这里是求object_box_prob > 0的anchor的位置。这个anchor可能与同一类别的两个gt的object_box_prob都大于0,也可能与不同类别的两个gt的object_box_prob都大于0。

                    indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()  # (23,2)->(2,23)
                    # tensor([[   12,    12,    12,    14,    14,    14,    14,    14,    14,    14,
                    #             14],
                    #         [16765, 16855, 16945, 13976, 14146, 14147, 14317, 14318, 14488, 14489,
                    #          16584]], device='cuda:0')

                    if indices.numel() == 0:  # 2*23=46
                        image_box_prob = torch.zeros(
                            anchors_.size(0),
                            self.cls_out_channels).type_as(object_box_prob)
                    else:
                        nonzero_box_prob = torch.where(
                            (gt_labels_.unsqueeze(dim=-1) == indices[0]),  # (2)->(2,1) == (15) -> (2,15)
                            object_box_prob[:, indices[1]],  # (2,17451)[:,(15)] -> (2,15)
                            torch.tensor([
                                0
                            ]).type_as(object_box_prob)).max(dim=0).values  # (2,15)->(15)
                        # 取max是因为可能存在同一个anchor与两个gt的object_box_prob都大于0,取大的那个

                        # print(indices.flip([0]))
                        # tensor([[16655, 16664, 16745, 16748, 16754, 16757, 16835, 16838, 16847, 16486,
                        #          16564, 16572, 16573, 16575, 16576, 16663, 16666],
                        #         [   12,    12,    12,    12,    12,    12,    12,    12,    12,    14,
                        #             14,    14,    14,    14,    14,    14,    14]], device='cuda:0')

                        # upmap to shape [j, c]
                        image_box_prob = torch.sparse_coo_tensor(
                            indices.flip([0]),
                            nonzero_box_prob,
                            size=(anchors_.size(0),
                                  self.cls_out_channels)).to_dense()  # (17451,20)
                    # end
                box_prob.append(image_box_prob)

            # construct bags for objects
            match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)  # (2,4),(17451,4) -> (2,17451)
            _, matched = torch.topk(
                match_quality_matrix,
                self.pre_anchor_topk,  # 50
                dim=1,
                sorted=False)  # (2,50)
            del match_quality_matrix

            # matched_cls_prob: P_{ij}^{cls}
            matched_cls_prob = torch.gather(
                cls_prob_[matched], 2,  # (17451,20)[(2,50)] -> (2,50,20)
                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,  # tensor([12,14]), (2)->(2,1,1)->(2,50,1)
                                                 1)).squeeze(2)  # (2,50,1)->(2,50)
            # exp(-BCE(matched_cls_prob)),这里exp和-BCE抵消了,所以还是matched_cls_prob

            # matched_box_prob: P_{ij}^{loc}
            matched_anchors = anchors_[matched]  # (17451,4)[(2,50)] -> (2,50,4)
            matched_object_targets = self.bbox_coder.encode(
                matched_anchors,
                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors))  # (2,4)->(2,1,4)->(2,50,4)

            loss_bbox = self.loss_bbox(  # SmoothL1Loss
                bbox_preds_[matched],  # (17451,4)[(2,50)] -> (2,50,4)
                matched_object_targets,  # (2,50,4)
                reduction_override='none').sum(-1)  # (2,50,4)->(2,50)
            matched_box_prob = torch.exp(-loss_bbox)  # (2,50)

            # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
            num_pos += len(gt_bboxes_)  # 2
            positive_losses.append(
                self.positive_bag_loss(matched_cls_prob, matched_box_prob))  # (2,50),(2,50) -> (2)
        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)  # 一个值,torch.Size([])

        # box_prob: P{a_{j} \in A_{+}}
        box_prob = torch.stack(box_prob, dim=0)

        # negative_loss:
        # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
        # (1,17451,20),(1,17451,20)
        negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
            1, num_pos * self.pre_anchor_topk)

        # avoid the absence of gradients in regression subnet
        # when no ground-truth in a batch
        if num_pos == 0:
            positive_loss = bbox_preds.sum() * 0

        losses = {
            'positive_bag_loss': positive_loss,
            'negative_bag_loss': negative_loss
        }
        return losses

    def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
        """Compute positive bag loss.

        :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`.

        :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of matched samples.

        :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched samples.

        Args:
            matched_cls_prob (Tensor): Classification probability of matched
                samples in shape (num_gt, pre_anchor_topk).
            matched_box_prob (Tensor): BBox probability of matched samples,
                in shape (num_gt, pre_anchor_topk).

        Returns:
            Tensor: Positive bag loss in shape (num_gt,).
        """  # noqa: E501, W605
        # bag_prob = Mean-max(matched_prob)
        matched_prob = matched_cls_prob * matched_box_prob  # (2,50)*(2,50)->(2,50)
        weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)  # (2,50)
        weight /= weight.sum(dim=1).unsqueeze(dim=-1)  # (2,50)->(2)->(2,1), (2,50)
        bag_prob = (weight * matched_prob).sum(dim=1)  # (2,50)->(2)
        # positive_bag_loss = -self.alpha * log(bag_prob)
        return self.alpha * F.binary_cross_entropy(
            bag_prob, torch.ones_like(bag_prob), reduction='none')

    def negative_bag_loss(self, cls_prob, box_prob):
        """Compute negative bag loss.

        :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`.

        :math:`P_{a_{j} \in A_{+}}`: Box_probability of matched samples.

        :math:`P_{j}^{bg}`: Classification probability of negative samples.

        Args:
            cls_prob (Tensor): Classification probability, in shape
                (num_img, num_anchors, num_classes).
            box_prob (Tensor): Box probability, in shape
                (num_img, num_anchors, num_classes).

        Returns:
            Tensor: Negative bag loss in shape (num_img, num_anchors, num_classes).
        """  # noqa: E501, W605
        prob = cls_prob * (1 - box_prob)  # cls_prob就是1-P_{j}^{bg}
        # There are some cases when neg_prob = 0.
        # This will cause the neg_prob.log() to be inf without clamp.
        prob = prob.clamp(min=EPS, max=1 - EPS)
        negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
            prob, torch.zeros_like(prob), reduction='none')
        return (1 - self.alpha) * negative_bag_loss

参考

FreeAnchor:令anchor自由匹配标签的策略(附源码实现) - 简书

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值