[mmdetection Source Code Walkthrough (2)] The RPN Network

The following is only my personal understanding; if anything is wrong, please point it out. Discussion is welcome!
I. Overall Process
  • Excerpt from mmdet/models/detectors/two_stage.py
        if self.with_rpn:  # rpn_head
            rpn_outs = self.rpn_head(x)   # x is the feature map (one tensor per FPN level)
            rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                          self.train_cfg.rpn)
            rpn_losses = self.rpn_head.loss(
                *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            losses.update(rpn_losses)
#--------------------------- divider ---------------------------------------
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals
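  • Here rpn_outs is a tuple (cls_scores, bbox_preds), each element being a list with one tensor per FPN level, so rpn_outs + (gt_bboxes, img_meta, cfg) is plain tuple concatenation. Below is a minimal sketch of the shapes involved; the level sizes, anchor count and batch size are illustrative assumptions, not values read from any config:

import torch

# Assumed: batch of 2 images, 5 FPN levels, 3 anchors per location,
# sigmoid classification -> 1 objectness channel per anchor.
num_anchors = 3
sizes = [(200, 336), (100, 168), (50, 84), (25, 42), (13, 21)]  # per-level (H, W)

cls_scores = [torch.randn(2, num_anchors * 1, h, w) for h, w in sizes]
bbox_preds = [torch.randn(2, num_anchors * 4, h, w) for h, w in sizes]
rpn_outs = (cls_scores, bbox_preds)  # what self.rpn_head(x) returns, shape-wise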
II. Detailed Walkthrough
1. RPN (rpn_head) network structure
  • The network structure is shown in the diagram below:
    [Figure: RPN head structure, i.e. a shared 3×3 conv followed by two parallel 1×1 convs for classification and box regression]
  • Code implementation
class RPNHead(AnchorHead):  # inherits from AnchorHead
    def __init__(self, in_channels, **kwargs):
        super(RPNHead, self).__init__(2, in_channels, **kwargs)   # RPN only does binary classification, so 2 --> num_classes

    def _init_layers(self):
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)  # the feature map first passes through a 3×3 conv; spatial size is unchanged
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)   # 1×1 conv; since RPN is binary (with sigmoid), cls_out_channels=1 suffices
        # self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)   # 1×1 conv; 4 corresponds to the 4 coordinate offsets

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):   # matches the structure diagram above
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred  # per-anchor objectness scores and box deltas
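  • To make the output shapes concrete, here is a self-contained sketch that mirrors forward_single with assumed config values (in_channels=256, 3 anchors per location, sigmoid classification):

import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels, feat_channels, num_anchors = 256, 256, 3  # assumed values
rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
rpn_cls = nn.Conv2d(feat_channels, num_anchors * 1, 1)  # 1 channel per anchor (sigmoid)
rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, 1)  # 4 deltas per anchor

x = torch.randn(1, in_channels, 50, 84)  # one FPN level
x = F.relu(rpn_conv(x), inplace=True)
print(rpn_cls(x).shape)  # torch.Size([1, 3, 50, 84])  -> objectness per anchor
print(rpn_reg(x).shape)  # torch.Size([1, 12, 50, 84]) -> (dx, dy, dw, dh) per anchor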
2. RPN loss
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  • In rpn_head.py
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        losses = super(RPNHead, self).loss(    # calls the parent-class method; the parent here is AnchorHead
            cls_scores,
            bbox_preds,
            gt_bboxes,
            None,    # gt_labels is None: RPN is class-agnostic
            img_metas,
            cfg,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
  • In anchor_head.py
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,    # None for RPN (class-agnostic)
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] # feature map size of each level
        assert len(featmap_sizes) == len(self.anchor_generators)
        device = cls_scores[0].device
        
        anchor_list, valid_flag_list = self.get_anchors(   # generate all anchor boxes
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1  
        cls_reg_targets = anchor_target(   # split anchors into positive/negative samples and sample them for RPN training; returns ground-truth labels, coordinate offsets, etc.
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,   # see below; multi_apply maps it over the FPN levels (sketch after loss_single)
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
    def loss_single(self, cls_score, bbox_pred, labels, label_weights,   # all of these are per-anchor quantities; cls_score and bbox_pred are the RPN predictions
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,   
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(  
            cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        return loss_cls, loss_bbox    # the classification and regression losses, computed separately
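  • multi_apply simply maps loss_single over the per-level lists and transposes the results. In mmdetection it is implemented essentially as follows (reproduced here as a simplified sketch):

from functools import partial

def multi_apply(func, *args, **kwargs):
    # Bind the shared keyword arguments, then call func once per feature level.
    # zip(*...) transposes [(cls0, box0), (cls1, box1), ...] into
    # ([cls0, cls1, ...], [box0, box1, ...]).
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))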
  • Further details are omitted here. In brief, the RPN generates a large number of anchor boxes according to the configured scales and aspect ratios (for an introduction to anchor boxes, see 理解anchor box究竟是如何生成的). Each anchor's ground-truth label is decided by its IoU with the gt boxes, and a subset is sampled, e.g. 256 anchors per mini-batch with a 1:1 positive/negative ratio. For anchors labeled positive, the true coordinate offsets relative to their matched gt boxes are then computed. Training drives the loss down so that the RPN's predictions approach these targets. A minimal sketch of the assignment and sampling step follows below.
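  • The real logic lives in anchor_target together with the configured assigner and sampler; the function below is a self-contained approximation, with the common RPN defaults (pos_iou_thr=0.7, neg_iou_thr=0.3, 256 samples, 1:1 ratio) assumed rather than read from any config:

import torch

def assign_and_sample(anchors, gt_bboxes, pos_iou_thr=0.7, neg_iou_thr=0.3,
                      num_samples=256, pos_fraction=0.5):
    # anchors: (N, 4), gt_bboxes: (M, 4), both in x1y1x2y2 format.
    # Pairwise IoU between every anchor and every gt box.
    area_a = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])
    area_g = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
    lt = torch.max(anchors[:, None, :2], gt_bboxes[None, :, :2])
    rb = torch.min(anchors[:, None, 2:], gt_bboxes[None, :, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    ious = inter / (area_a[:, None] + area_g[None, :] - inter)  # (N, M)

    max_iou, _ = ious.max(dim=1)
    labels = torch.full((anchors.size(0),), -1, dtype=torch.long)  # -1: ignored
    labels[max_iou < neg_iou_thr] = 0   # background
    labels[max_iou >= pos_iou_thr] = 1  # foreground
    labels[ious.argmax(dim=0)] = 1      # each gt also keeps its highest-IoU anchor

    # Randomly sample up to num_samples * pos_fraction positives,
    # then fill the rest of the mini-batch with negatives.
    pos_inds = torch.nonzero(labels == 1).squeeze(1)
    neg_inds = torch.nonzero(labels == 0).squeeze(1)
    num_pos = min(int(num_samples * pos_fraction), pos_inds.numel())
    pos_inds = pos_inds[torch.randperm(pos_inds.numel())[:num_pos]]
    neg_inds = neg_inds[torch.randperm(neg_inds.numel())[:num_samples - num_pos]]
    return pos_inds, neg_inds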

3. Generating region proposals
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
  • In anchor_head.py
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=False):
       
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores) 
        device = cls_scores[0].device
        mlvl_anchors = [       # all anchors have to be generated again here
            self.anchor_generators[i].grid_anchors(
                cls_scores[i].size()[-2:],   
                self.anchor_strides[i],
                device=device) for i in range(num_levels)
        ]
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)   
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,  # get the region proposals for one image; see below
                                               mlvl_anchors, img_shape,
                                               scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list
  • In rpn_head.py
    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()   # binary classification via sigmoid
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:  # e.g. cfg.nms_pre=2000
                _, topk_inds = scores.topk(cfg.nms_pre)   # preliminary filtering by predicted objectness score
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)  # decode the anchors plus the RPN's predicted offsets into actual proposal boxes
            if cfg.min_bbox_size > 0:  
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                           (h >= cfg.min_bbox_size)).squeeze()
                proposals = proposals[valid_inds, :]
                scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)  # filter with NMS, e.g. cfg.nms_thr=0.7
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:     
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])   # e.g. cfg.max_num=2000
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
  • This is the proposal-generation stage: all anchor boxes are regenerated, a preliminary top-k filtering is applied based on the predicted objectness scores, the RPN's predicted coordinate offsets are added onto the anchors to obtain more accurate boxes, and NMS (non-maximum suppression) then yields the final region proposals. A sketch of the delta-decoding step follows below.
  • Because the RPN in the source code takes multiple feature maps (one per FPN level) as input, a fully detailed analysis is slightly more involved, but the principle is the same at every level.
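  • delta2bbox inverts the target encoding: a predicted (dx, dy, dw, dh) shifts the anchor center and rescales its width and height. The sketch below is a simplified version (it omits the means/stds de-normalization and the max-ratio clamp of the real function):

import torch

def delta2bbox_sketch(anchors, deltas, img_shape):
    # anchors, deltas: (N, 4); img_shape: (H, W). Simplified decoding.
    w = anchors[:, 2] - anchors[:, 0] + 1.0
    h = anchors[:, 3] - anchors[:, 1] + 1.0
    cx = anchors[:, 0] + 0.5 * (w - 1)
    cy = anchors[:, 1] + 0.5 * (h - 1)

    dx, dy, dw, dh = deltas.unbind(dim=1)
    pred_cx = cx + w * dx   # shift the center by a width/height-relative amount
    pred_cy = cy + h * dy
    pred_w = w * dw.exp()   # rescale width/height in log space
    pred_h = h * dh.exp()

    # Convert back to x1y1x2y2 and clip to the image boundary.
    x1 = (pred_cx - 0.5 * (pred_w - 1)).clamp(0, img_shape[1] - 1)
    y1 = (pred_cy - 0.5 * (pred_h - 1)).clamp(0, img_shape[0] - 1)
    x2 = (pred_cx + 0.5 * (pred_w - 1)).clamp(0, img_shape[1] - 1)
    y2 = (pred_cy + 0.5 * (pred_h - 1)).clamp(0, img_shape[0] - 1)
    return torch.stack([x1, y1, x2, y2], dim=1)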