【mmdetection源码解读(二)】RPN网络

以下仅为个人理解,若有不正之处还请指出,欢迎交流!
一、整体过程
  • mmdet/models/detectors/two_stage.py中的部分代码
        # RPN stage of the two-stage detector: compute the RPN losses, then
        # generate proposals from the same RPN outputs.
        if self.with_rpn:  # rpn_head
            rpn_outs = self.rpn_head(x)   # x is the feature map
            # rpn_outs is extended by tuple concatenation to form the
            # positional arguments of rpn_head.loss.
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)
#---------------------------divider---------------------------------------
            # Use the 'rpn_proposal' entry of train_cfg when present,
            # otherwise fall back to test_cfg.rpn.
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            # No RPN head: use the `proposals` value from the enclosing scope
            # (its origin is not shown in this excerpt).
            proposal_list = proposals
二、详细解读
1.RPN(rpn_head)网络结构
  • 网络结构示意图如下图所示:
    在这里插入图片描述
  • 代码实现
class RPNHead(AnchorHead):
    """RPN head derived from ``AnchorHead``.

    A shared 3x3 convolution feeds two sibling 1x1 convolutions: one emits
    classification scores and the other emits four regression values, both
    per anchor at every spatial position of the feature map.
    """

    def __init__(self, in_channels, **kwargs):
        # The leading 2 is the base class's first positional argument
        # (num_classes, per the original author's note): the RPN only
        # performs a two-class decision.
        super(RPNHead, self).__init__(2, in_channels, **kwargs)

    def _init_layers(self):
        """Create the shared convolution and the two prediction branches."""
        # NOTE(review): the original author notes that num_anchors equals
        # len(anchor_ratios) * len(anchor_scales) and that cls_out_channels
        # can be 1 for the binary RPN case; both are set outside this class,
        # so confirm against AnchorHead.
        cls_channels = self.num_anchors * self.cls_out_channels
        reg_channels = self.num_anchors * 4  # four box offsets per anchor

        # 3x3 convolution with padding=1, so the spatial size is preserved.
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, kernel_size=3, padding=1)
        # 1x1 convolutions for the classification and regression outputs.
        self.rpn_cls = nn.Conv2d(self.feat_channels, cls_channels, 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, reg_channels, 1)

    def init_weights(self):
        """Initialise all three convolutions with ``normal_init(std=0.01)``."""
        for layer in (self.rpn_conv, self.rpn_cls, self.rpn_reg):
            normal_init(layer, std=0.01)

    def forward_single(self, x):
        """Run the head on a single feature map.

        Returns:
            A ``(rpn_cls_score, rpn_bbox_pred)`` tuple produced by the two
            1x1 branches from the ReLU-activated shared feature.
        """
        shared = F.relu(self.rpn_conv(x), inplace=True)
        return self.rpn_cls(shared), self.rpn_reg(shared)
2.RPN_Loss
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  • rpn_head.py中
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the RPN losses by delegating to the parent class.

        The parent ``loss`` is called with ``None`` in the ground-truth
        label position, and its ``loss_cls`` / ``loss_bbox`` entries are
        returned under the RPN-specific keys ``loss_rpn_cls`` /
        ``loss_rpn_bbox``.
        """
        parent_losses = super(RPNHead, self).loss(
            cls_scores,
            bbox_preds,
            gt_bboxes,
            None,  # gt_labels: not supplied for the RPN
            img_metas,
            cfg,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return {
            'loss_rpn_cls': parent_losses['loss_cls'],
            'loss_rpn_bbox': parent_losses['loss_bbox'],
        }
  • anchor_head.py中
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute per-level classification and regression losses.

        Anchors are gathered for every feature-map size, turned into
        training targets by ``anchor_target``, and ``loss_single`` is then
        applied across levels through ``multi_apply``.

        Args:
            cls_scores: per-level classification outputs; the last two
                dimensions of each give the feature-map size.
            bbox_preds: per-level regression outputs.
            gt_bboxes: ground-truth boxes, forwarded to ``anchor_target``.
            gt_labels: ground-truth labels, forwarded as ``gt_labels_list``
                (the RPN head passes ``None``).
            img_metas: image meta information.
            cfg: training config forwarded to ``anchor_target`` and
                ``loss_single``.
            gt_bboxes_ignore: boxes to ignore, forwarded to ``anchor_target``.

        Returns:
            ``dict(loss_cls=..., loss_bbox=...)``, or ``None`` when
            ``anchor_target`` returns ``None``.
        """
        # Spatial size of every feature level.
        featmap_sizes = [score_map.size()[-2:] for score_map in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        # Collect all anchors together with their validity flags.
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=cls_scores[0].device)

        if self.use_sigmoid_cls:
            label_channels = self.cls_out_channels
        else:
            label_channels = 1

        # Per the original author's note: splits anchors into positive and
        # negative samples (with sampling) and returns labels, regression
        # targets and related information.
        targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if targets is None:
            return None

        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = targets

        # With sampling the normaliser counts positives and negatives;
        # without it, only positives.
        if self.sampling:
            num_total_samples = num_total_pos + num_total_neg
        else:
            num_total_samples = num_total_pos

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)
        return {'loss_cls': losses_cls, 'loss_bbox': losses_bbox}
    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        """Compute the classification and regression loss for one level.

        ``cls_score`` and ``bbox_pred`` are the network predictions; the
        remaining tensors are the per-anchor targets and weights. Predictions
        are moved to channel-last layout and flattened so that each row lines
        up with one flattened target entry.

        Returns:
            ``(loss_cls, loss_bbox)`` as returned by ``self.loss_cls`` and
            ``self.loss_bbox``, both called with
            ``avg_factor=num_total_samples``.
        """
        # --- classification branch ---
        flat_labels = labels.reshape(-1)
        flat_label_weights = label_weights.reshape(-1)
        channel_last_scores = cls_score.permute(0, 2, 3, 1)
        flat_scores = channel_last_scores.reshape(-1, self.cls_out_channels)
        cls_loss = self.loss_cls(
            flat_scores,
            flat_labels,
            flat_label_weights,
            avg_factor=num_total_samples)

        # --- regression branch ---
        flat_targets = bbox_targets.reshape(-1, 4)
        flat_bbox_weights = bbox_weights.reshape(-1, 4)
        flat_preds = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        reg_loss = self.loss_bbox(
            flat_preds,
            flat_targets,
            flat_bbox_weights,
            avg_factor=num_total_samples)

        return cls_loss, reg_loss
  • 更多细节不再展开,概括地讲,RPN网络会根据设定的尺度和纵横比生成大量anchor box(anchor box介绍可参考理解anchor box究竟是如何生成的),通过anchor box与gt box之间的IoU值来判定anchor box的真实标签,并进行采样,例如一个mini-batch选择256个anchor,设定正负样本比例1:1,再计算标记为正样本的anchor box相对于gt box的真实坐标偏移量。然后通过训练使loss收敛,以使RPN网络的预测结果更接近真实结果。

3.生成区域建议候选框(proposals)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
  • anchor_head.py中
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        """Produce one result per image from the multi-level head outputs.

        Anchors are regenerated once per level from the feature-map size and
        stride. For each image, the detached per-level predictions are then
        handed to ``get_bboxes_single``.

        Args:
            cls_scores: per-level classification outputs, indexed as
                ``cls_scores[level][img_id]``.
            bbox_preds: per-level regression outputs, same indexing.
            img_metas: one mapping per image providing ``'img_shape'`` and
                ``'scale_factor'``.
            cfg: config forwarded to ``get_bboxes_single``.
            rescale: forwarded to ``get_bboxes_single``.

        Returns:
            A list with one ``get_bboxes_single`` result per image.
        """
        assert len(cls_scores) == len(bbox_preds)
        levels = range(len(cls_scores))
        device = cls_scores[0].device

        # The anchors have to be generated again here, one set per level.
        mlvl_anchors = []
        for lvl in levels:
            level_anchors = self.anchor_generators[lvl].grid_anchors(
                cls_scores[lvl].size()[-2:],
                self.anchor_strides[lvl],
                device=device)
            mlvl_anchors.append(level_anchors)

        result_list = []
        for img_id, meta in enumerate(img_metas):
            # Detach so the predictions are cut off from the autograd graph.
            per_img_scores = [cls_scores[lvl][img_id].detach()
                              for lvl in levels]
            per_img_preds = [bbox_preds[lvl][img_id].detach()
                             for lvl in levels]
            result_list.append(
                self.get_bboxes_single(per_img_scores, per_img_preds,
                                       mlvl_anchors, meta['img_shape'],
                                       meta['scale_factor'], cfg, rescale))
        return result_list
  • rpn_head.py中
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        """Turn one image's multi-level RPN outputs into scored proposals.

        Per level: keep the ``cfg.nms_pre`` highest-scoring anchors, decode
        them with ``delta2bbox``, drop boxes smaller than
        ``cfg.min_bbox_size``, run NMS and keep at most ``cfg.nms_post``
        boxes. The levels are then concatenated and reduced to at most
        ``cfg.max_num`` rows.

        Args:
            cls_scores: per-level classification maps for one image; each is
                permuted to channel-last and flattened.
            bbox_preds: per-level regression maps for one image; each is
                flattened to rows of 4 values.
            mlvl_anchors: per-level anchors, indexed in step with the
                flattened predictions.
            img_shape: forwarded to ``delta2bbox``.
            scale_factor: accepted for interface compatibility; not used in
                this method.
            cfg: config providing ``nms_pre``, ``min_bbox_size``,
                ``nms_thr``, ``nms_post``, ``nms_across_levels`` and
                ``max_num``.
            rescale: accepted for interface compatibility; not used in this
                method.

        Returns:
            A 2-D tensor whose rows are the four box coordinates followed by
            the score in column 4.
        """
        mlvl_proposals = []
        for idx, rpn_cls_score in enumerate(cls_scores):
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                # One logit per anchor.
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                # Two logits per anchor; column 1 is used as the score.
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]

            # Coarse pre-NMS filtering by predicted score (e.g. nms_pre=2000).
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]

            # Apply the predicted offsets to the anchors to obtain the
            # actual proposal boxes.
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)

            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                # Index with the boolean mask directly. The previous
                # ``torch.nonzero(...).squeeze()`` collapsed to a 0-dim
                # index when exactly one box survived, which dropped the
                # row dimension of ``proposals`` and ``scores``.
                valid_mask = (w >= cfg.min_bbox_size) & (
                    h >= cfg.min_bbox_size)
                proposals = proposals[valid_mask, :]
                scores = scores[valid_mask]

            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            # Per-level NMS (e.g. nms_thr=0.7), then cap the count.
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)

        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            # Keep the top-scoring rows across all levels (e.g. max_num=2000).
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
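  • 该部分即为获取区域建议候选框的过程:首先重新生成所有anchor box,在每一层特征图上按预测置信度取前nms_pre个进行初步筛选;然后将anchor box叠加RPN网络预测的坐标偏移量,得到更为准确的候选框,并滤除宽或高小于min_bbox_size的候选框;再通过NMS(非极大值抑制)去除重叠框,每层最多保留nms_post个;最后将各层结果合并(若nms_across_levels为真则再做一次NMS),按置信度保留至多max_num个,得到最终的区域建议候选框。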
  • 由于源代码中RPN网络的输入是多个feature map(多层特征图),所以更详细的分析会略微复杂一些,但原理都是相同的。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值