ATSS论文笔记

ATSS论文笔记

1 整体概述

核心思想:
Anchor-based与Anchor-free的本质区别在于其正负样本的分配方式不同。因此文章提出一种自适应训练样本选择的方法,从而使得网络能根据统计特征自适应选择正负样本。
RetinaNet和Fcos网络:

三大区别:

  • 锚框数目不同:Fcos相当于是每个锚点一个锚框
  • 正负样本分类不同:RetinaNet是利用IoU划分,而FCOS则是利用空间和尺度信息划分
  • 回归的起点不同:RetinaNet是利用锚框回归,而FCOS是利用锚点回归

这篇论文主要的贡献:

  • 分析了Anchor-based与Anchor-free的本质区别
  • 提出了一种自适应训练样本分配方法
  • 阐述了在图像中每个位置平铺多个锚点以检测对象是一个无用的操作

2 前人工作

Anchor-based与Anchor-free:

Anchor-based方法包括两阶段与一阶段。

Two-Stage网络:Fast-Rcnn是两阶段的起源,其可以分为一个候选区域生成网络(RPN)和一个区域预测网络。之后,提出了许多算法来提高其性能,包括架构重新设计和重新格式化、上下文和注意机制 、多尺度训练和测试、训练策略和损失函数, 特征融合和增强, 更好的提议和平衡。

One-Stage 网络:SSD的出现是单阶段网络的开始,推广了多尺度锚框预测方法,并在基于锚框进行分类和回归。一些新的方法不断被提出:融合来自不同层的上下文信息、从头开始训练、引入新的损失函数、锚点细化和匹配、架构重新设计、特征丰富和对齐。

基于关键点与基于中心:

基于关键点的方法:CornerNet是引入左上角和右下角的一对关键点,而CornerNet-Lite则是引入了Scade和Squeeze方法提高了速度,Grid RCNN则是预测具有位置敏感的Grid来作为定位对象,ExtremeNet预测四个极值点和一个中心点实现回归。CenterNet将CornerNet扩展为三元组提高精度,Reppoints使用点集作为回归目标。

基于中心点:DenseBox使用位于对象中心的填充圆来定义正数,然后预测从正数到对象边界框边界的四个距离以进行位置。GA-RPN将物体中心区域的像素定义为预测Faster R-CNN对象提议的位置、宽度和高度的积极因素。FFSAF将无锚分支与在线特征选择附加到RetinaNet。新添加的分支将对象的中心区域定义为正样本,通过预测到其边界的四个距离来定位它。FCOS将对象边界框内的所有位置视为四个距离的阳性和一个新的中心度分数来检测对象。CSP仅将对象框的中心点定义为阳性,以检测具有固定纵横比的行人。FoveaBox将对象中间部分的位置视为具有四个距离来执行检测的正样本。

3 本文创新

3.1 论证正负样本的选择是关键
  • 首先通过消融实验,对比了RetinaNet和FCOS,并逐步添加了FCOS上与Anchor-Free无关的一些策略
    在这里插入图片描述

实验明显表明了,在增加这些FCOS才有的策略后,两者之间的差距在不断缩小。

  • 其次验证是检测分类子任务中正负样本的划分带来的精度区别还是回归分支中关于锚点还是锚框带来的精度区别

正负样本的选择不同


FCOS:先在空间条件限制和空间维度上选择候选的正样本,再在尺度上确定最终的正负样本

RetinaNet: 利用IoU直接在空间和尺度上一次性确定最终的正负样本

回归方式不同


FCOS:利用锚点回归ltbr四个参数

RetinaNet: 利用锚框回归 △ x \triangle x x △ y \triangle y y △ w \triangle w w △ h \triangle h h四个参数
在这里插入图片描述
替换实验表明,如果在算法中分别固定用IoU和空间尺度限制的方式,不同的回归方式对精度几乎没有影响,证明了真正关键的是正负样本的分配机制,即是IoU还是空间尺度限制。

3.2 ATSS的提出与原理

ATSS算法:
在这里插入图片描述

如果有一个锚框被分配到多个真实框,则选择IOU更高的框作为分配对象

具体步骤:

  • 首先在每一个真实框周围的每一级FPN层中挑选具体真实框中心最近的k个Anchor,组成候选正样本集合;
  • 计算候选正样本集合到真实框g之间的IoU,计算IoU集合的均值、方差、并将均值+方差设置为阈值
  • 对于候选框进行筛选,大于阈值的且锚框中心点在真实框中的设置为正样本,加入正样本集合
  • 循环遍历每一个真实框,并重复上述步骤筛选真实的正样本
  • 负样本=总的Anchor - 正样本的Anchor

4 代码复现

				num_anchors_per_loc = len(self.cfg.MODEL.ATSS.ASPECT_RATIOS) * self.cfg.MODEL.ATSS.SCALES_PER_OCTAVE
    			# num_anchors_per_level:每级FPN上锚框的数量;[10000, 2500, 625, 169, 49]
                num_anchors_per_level = [len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]]

                # 计算一张图片上的bboxs与gt的iou,得到的size:[bboxs_num,gt_num]
                ious = boxlist_iou(anchors_per_im, targets_per_im)

                # 计算bboxs的中心点
                gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0
                gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0
                gt_points = torch.stack((gt_cx, gt_cy), dim=1)

                # 计算gts的中心点
                anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0
                anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0
                anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)

                # 计算中心点之间的L2距离 输出size:[bboxs_num,gt_num]
                distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()


                # 每一级FPN上,假设L级,则一共筛选k*l个候选正样本
                candidate_idxs = []
                star_idx = 0
                # 遍历每一级
                for level, anchors_per_level in enumerate(anchors[im_i]):
                    end_idx = star_idx + num_anchors_per_level[level]
                    distances_per_level = distances[star_idx:end_idx, :]
                    # 索引每一金字塔级别上的锚框与真实框的距离
                    topk = min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level])
                    # topk一般设为9
                    _, topk_idxs_per_level = distances_per_level.topk(topk, dim=0, largest=False)
                    # topk_idxs_per_level:每一金字塔级别上最接近真实框的k个锚框的索引
                    candidate_idxs.append(topk_idxs_per_level + star_idx)
                    # 金字塔上的序号是按照级累加索引的,所以下一级开始的是上一级的结束索引
                    star_idx = end_idx
                candidate_idxs = torch.cat(candidate_idxs, dim=0)
                # candidate_idxs:候选正样本的索引
                # candidate_idxs.size():[k*l, gts_num]
                
                candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
                # candidate_ious:候选正样本对应的iou
                # candidate_ious.size():[k*l, gts_num]
                iou_mean_per_gt = candidate_ious.mean(0) #计算均值
                # iou_mean_per_gt.size(): [gts_num];
                iou_std_per_gt = candidate_ious.std(0) #计算方差
                # iou_std_per_gt.size(): [gts_num];
                iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt #得到阈值
                # iou_thresh_per_gt.size(): [gts_num]
                is_pos = candidate_ious >= iou_thresh_per_gt[None, :]

                anchor_num = anchors_cx_per_im.shape[0]

                # 使得几个真实框对应的锚框铺成一维时,仍能够被索引到
                for ng in range(num_gt):
                    candidate_idxs[:, ng] += ng * anchor_num

                # 将几个真实框的锚框铺成一维
                e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
                e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
                # e_anchors_cx:锚框中心点的横坐标;e_anchors_cy:锚框中心点的纵坐标
                # e_anchors_cx.size(): [bboxs_num*gts_num];
                # e_anchors_cy.size(): [bboxs_num*gts_num]; 

                # 将几个真实框的候选正样本的索引铺成一维
                candidate_idxs = candidate_idxs.view(-1)

                # 筛选出中心点位于对应的真实框内的,作为候选正样本
                l = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0]
                t = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1]
                r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt)
                b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt)
                is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0] > 0.01

                # 候选正样本中高于对应的iou阈值,且中心点位于对应的真实框内
                is_pos = is_pos & is_in_gts

                # 如果一个锚框被多个真实框所选择,则其归于iou较高的真实框
               
                ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1)
                # ‘INF’是作者自己定义的值,INF = 100000000
                # ious_inf是经过原来iou转置过的
                # ious_inf.size(): [bboxs_num*gts_num];
                index = candidate_idxs.view(-1)[is_pos.view(-1)]
                # 得到候选正样本中高于对应的iou阈值,且中心点位于对应的真实框内的索引
                ious_inf[index] = ious.t().contiguous().view(-1)[index]
                # ious_inf中的正样本的iou赋予原本的iou,其它都赋为-INF
                ious_inf = ious_inf.view(num_gt, -1).t()
                anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)

                cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
                cls_labels_per_im[anchors_to_gt_values == -INF] = 0
                matched_gts = bboxes_per_im[anchors_to_gt_indexs]
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值