ATSS核心代码解析【Pytorch】

论文全称:Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training
论文地址:https://arxiv.org/abs/1912.02424
ATSS论文讲解可参考我写的这篇博客笔记:https://blog.csdn.net/chenzhoujian_/article/details/109144860
源码:https://github.com/sfzhang15/ATSS
ATSS核心代码:https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py

            elif self.cfg.MODEL.ATSS.POSITIVE_TYPE == 'ATSS':
                # 注意:以下的过程都是对于一张图片来操作的

                num_anchors_per_loc = len(self.cfg.MODEL.ATSS.ASPECT_RATIOS) * self.cfg.MODEL.ATSS.SCALES_PER_OCTAVE
                # num_anchors_per_loc:每个位置锚框的数量
                num_anchors_per_level = [len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]]
                # num_anchors_per_level:每个金字塔级别上锚框的数量;[10000, 2500, 625, 169, 49]

                # 计算一张图片上的锚框与真实框的iou,后面会用到
                ious = boxlist_iou(anchors_per_im, targets_per_im)
                # ious.size():[a, b];  a:一张图片上锚框的数量;b:一张图片上真实框的数量

                # 计算锚框的中心点
                gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0
                gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0
                gt_points = torch.stack((gt_cx, gt_cy), dim=1)

                # 计算真实框的中心点
                anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0
                anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0
                anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)

                # 计算锚框与真实框中心点之间的L2距离,‘None’的作用是广播
                distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()
                # distances.size():[a, b];  a:一张图片上锚框的数量;b:一张图片上真实框的数量


                # 在每一个金字塔级别上,根据L2距离选择其中心最接近真实框g的k个锚框;假设有L个金字塔级别,则每个真实框g有k×L个候选正样本
                ##################################################################################################################
                candidate_idxs = []
                star_idx = 0
                for level, anchors_per_level in enumerate(anchors[im_i]):
                    end_idx = star_idx + num_anchors_per_level[level]
                    distances_per_level = distances[star_idx:end_idx, :]
                    # 索引每一金字塔级别上的锚框与真实框的距离
                    topk = min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level])
                    # topk:超参数k,一般设为9
                    _, topk_idxs_per_level = distances_per_level.topk(topk, dim=0, largest=False)
                    # topk_idxs_per_level:每一金字塔级别上最接近真实框的k个锚框的索引
                    candidate_idxs.append(topk_idxs_per_level + star_idx)
                    # 这步操作使得后面金字塔级别上的锚框能够被索引
                    star_idx = end_idx
                candidate_idxs = torch.cat(candidate_idxs, dim=0)
                # candidate_idxs:候选正样本的索引;
                # candidate_idxs.size():[a, b];  a:k*L(有几个金字塔级别);b:一张图片上真实框的数量
                ##################################################################################################################


                # 计算iou的均值和标准差,得到每个真实框的iou阈值并进行比较
                ##################################################################################################################
                candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
                # candidate_ious:候选正样本对应的iou
                # candidate_ious.size(): 同candidate_idxs
                iou_mean_per_gt = candidate_ious.mean(0)
                # iou_mean_per_gt.size(): [a];  a:一张图片上真实框的数量
                iou_std_per_gt = candidate_ious.std(0)
                # iou_std_per_gt.size(): [a];  a:一张图片上真实框的数量
                iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt
                # iou_thresh_per_gt.size(): [a];  a:一张图片上真实框的数量
                is_pos = candidate_ious >= iou_thresh_per_gt[None, :]
                # ‘None’起到广播的作用
                ##################################################################################################################


                # 将正样本的中心点限制在真实框内
                ##################################################################################################################
                anchor_num = anchors_cx_per_im.shape[0]
                # anchor_num:一张图片上锚框的数量
                # 使得几个真实框对应的锚框铺成一维时,仍能够被索引到
                for ng in range(num_gt):
                    candidate_idxs[:, ng] += ng * anchor_num

                # 将几个真实框的锚框铺成一维
                e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
                e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
                # e_anchors_cx:锚框中心点的横坐标;e_anchors_cy:锚框中心点的纵坐标
                # e_anchors_cx.size(): [a];  a:一张图片上锚框的数量 × 一张图片上真实框的数量
                # e_anchors_cy.size(): [a];  a:一张图片上锚框的数量 × 一张图片上真实框的数量

                # 将几个真实框的候选正样本的索引铺成一维
                candidate_idxs = candidate_idxs.view(-1)

                # 筛选出中心点位于对应的真实框内的候选正样本
                # 这里的做法就是将锚框中心点的横坐标限制在真实框的Xmin与Xmax之间,纵坐标限制在Ymin与Ymax之间,可以自行画个坐标图帮助理解(画图时注意,原点在左上角,x轴以右为正方向,y轴以下为正方向)
                l = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0]
                t = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1]
                r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt)
                b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt)
                is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0] > 0.01

                # 候选正样本中高于对应的iou阈值,且中心点位于对应的真实框内
                is_pos = is_pos & is_in_gts
                ##################################################################################################################


                # 如果一个锚框被多个真实框所选择,则其归于iou较高的真实框
                ##################################################################################################################
                ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1)
                # ‘INF’是作者自己定义的值,INF = 100000000
                # ious_inf是经过原来iou转置过的
                # ious_inf.size(): [a];  a: 一张图片上的锚框数量 × 一张图片上的真实框数量
                index = candidate_idxs.view(-1)[is_pos.view(-1)]
                # 得到候选正样本中高于对应的iou阈值,且中心点位于对应的真实框内的索引
                ious_inf[index] = ious.t().contiguous().view(-1)[index]
                # ious_inf中的正样本的iou赋予原本的iou,其它都赋为-INF
                ious_inf = ious_inf.view(num_gt, -1).t()
                anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)
                ##################################################################################################################

                cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
                cls_labels_per_im[anchors_to_gt_values == -INF] = 0
                matched_gts = bboxes_per_im[anchors_to_gt_indexs]

  • 5
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值