Reading the mmdetection v1.2 Source, Part 2: The Network Forward Pass, Using Faster RCNN+FPN as an Example

1. Overview

  • This article walks through how mmdetection computes the loss during the forward pass, using Faster RCNN+FPN as the running example.

2. Source Code Walkthrough

  • 2.1 We start from tools/train.py, which creates the model with the statement below. During training, PyTorch's forward pass invokes the model's forward function; our goal is to trace how that call unfolds (a sketch follows the snippet).
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
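For orientation, here is a hedged sketch (illustrative variable names, not verbatim train.py code) of how the training loop reaches forward: calling the model like a function goes through nn.Module.__call__, which dispatches to forward.

# Sketch only: in the real pipeline the batch comes from the dataloader and
# the call is wrapped by the runner / MMDataParallel.
losses = model(img, img_meta, gt_bboxes=gt_bboxes, gt_labels=gt_labels,
               return_loss=True)  # nn.Module.__call__ -> model.forward(...)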
  • 2.2 The core of the Faster RCNN detection framework is the FasterRCNN class, defined in mmdet/models/detectors/faster_rcnn.py; its parent class is TwoStageDetector. The FasterRCNN class is defined as follows:
class FasterRCNN(TwoStageDetector):

    def __init__(self):
        # some code
        pass
  • 2.3 The TwoStageDetector class is declared as follows; its parents are BaseDetector, RPNTestMixin, BBoxTestMixin, and MaskTestMixin. The network components are assembled inside TwoStageDetector (a rough sketch of that construction pattern follows the declaration).
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
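Its constructor (not shown here) builds each component from the config. Below is a simplified sketch of that pattern, assuming mmdet's builder helpers in mmdet/models/builder.py; this is not the verbatim source.

# Simplified sketch of TwoStageDetector.__init__ (not verbatim):
self.backbone = builder.build_backbone(backbone)        # e.g. ResNet
if neck is not None:
    self.neck = builder.build_neck(neck)                # e.g. FPN
self.rpn_head = builder.build_head(rpn_head)
self.bbox_roi_extractor = builder.build_roi_extractor(bbox_roi_extractor)
self.bbox_head = builder.build_head(bbox_head)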
  • 2.4 The forward pass calls the model's forward function, yet the TwoStageDetector class body defines no such function. How so? This is class inheritance at work: when a method is looked up on a class, Python starts at the class itself and climbs through its parent classes (climbing the class tree). Here forward is found in TwoStageDetector's parent BaseDetector, and the inherited forward effectively becomes a method of TwoStageDetector, so the self.forward_train it calls is also resolved as a method of TwoStageDetector. The forward function is defined below; it dispatches to forward_train, which is where the forward pass and loss computation actually happen (a toy illustration of the lookup follows the snippet, and the annotated forward_train listing comes after that).
def forward(self, img, img_meta, return_loss=True, **kwargs):
    if return_loss:
        return self.forward_train(img, img_meta, **kwargs)
    else:
        return self.forward_test(img, img_meta, **kwargs)
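To make the class-tree climb concrete, here is a toy example (not mmdetection code) showing lookup climbing to the parent and then dispatching back down through self:

class Base:
    def forward(self):
        # defined only on the parent; self is still the child instance
        return self.forward_train()

class Child(Base):
    def forward_train(self):
        return "Child.forward_train"

print(Child().forward())  # lookup climbs to Base.forward, which then
                          # resolves self.forward_train back on Child

The forward_train function, which does the real work, is defined as follows; the code is annotated inline.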
def forward_train(self,
                  img,
                  img_meta,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  gt_masks=None,
                  proposals=None):
    """
    Args:
        img (Tensor): of shape (N, C, H, W) encoding input images.
            Typically these should be mean centered and std scaled.

        img_meta (list[dict]): list of image info dict where each dict has:
            'img_shape', 'scale_factor', 'flip', and may also contain
            'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
            For details on the values of these keys see
            `mmdet/datasets/pipelines/formatting.py:Collect`.

        gt_bboxes (list[Tensor]): each item are the truth boxes for each
            image in [tl_x, tl_y, br_x, br_y] format.

        gt_labels (list[Tensor]): class indices corresponding to each box

        gt_bboxes_ignore (None | list[Tensor]): specify which bounding
            boxes can be ignored when computing the loss.

        gt_masks (None | Tensor) : true segmentation masks for each box
            used if the architecture supports a segmentation task.

        proposals : override rpn proposals with custom proposals. Use when
            `with_rpn` is False.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    x = self.extract_feat(img)  # img -> |backbone+neck| -> base_feature

    losses = dict()

    # RPN forward and loss: run the RPN branch and generate proposals
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)  # rpn_outs = ((rpn_cls_score1, ..., rpn_cls_score5),
                                     #             (rpn_bbox_pred1, ..., rpn_bbox_pred5))
                                     # a tuple of two per-FPN-level lists;
                                     # each tensor has shape (num_imgs, channels, H, W)
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                      self.train_cfg.rpn)  # inputs for computing the RPN loss
        rpn_losses = self.rpn_head.loss(
            *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)  # compute the RPN loss
        losses.update(rpn_losses)  # merge the RPN losses into the loss dict

        proposal_cfg = self.train_cfg.get('rpn_proposal',
                                          self.test_cfg.rpn)
        proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)  # RPN proposals, one tensor per image
    else:
        proposal_list = proposals

    # assign gts and sample proposals: match proposals to the ground truth,
    # then sample positives/negatives for training
    if self.with_bbox or self.with_mask:
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(
            self.train_cfg.rcnn.sampler, context=self)
        num_imgs = img.size(0)
        if gt_bboxes_ignore is None:
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        for i in range(num_imgs):
            assign_result = bbox_assigner.assign(proposal_list[i],
                                                 gt_bboxes[i],
                                                 gt_bboxes_ignore[i],
                                                 gt_labels[i])
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes[i],
                gt_labels[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
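        # each sampling_result stores the sampled positive/negative proposals
        # (pos_bboxes / neg_bboxes) and the matched ground-truth labels
        # (pos_gt_labels), which the bbox and mask heads consume below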

    # bbox head forward and loss: pool features for each proposal, run the
    # classification and regression branches, and compute their respective losses
    if self.with_bbox:
        rois = bbox2roi([res.bboxes for res in sampling_results])  # prepend an image index to each proposal
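        # note: rois is a single (num_rois, 5) tensor whose first column is the
        # image index within the batch, so one extractor call serves all images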
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)  # RoI-pooled features
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_feats)

        bbox_targets = self.bbox_head.get_target(sampling_results,
                                                 gt_bboxes, gt_labels,
                                                 self.train_cfg.rcnn)
        loss_bbox = self.bbox_head.loss(cls_score, bbox_pred,
                                        *bbox_targets)
        losses.update(loss_bbox)

    # mask head forward and loss
    if self.with_mask:
        if not self.share_roi_extractor:
            pos_rois = bbox2roi(
                [res.pos_bboxes for res in sampling_results])
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], pos_rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
        else:
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
                pos_inds.append(
                    torch.ones(
                        res.pos_bboxes.shape[0],
                        device=device,
                        dtype=torch.uint8))
                pos_inds.append(
                    torch.zeros(
                        res.neg_bboxes.shape[0],
                        device=device,
                        dtype=torch.uint8))
            pos_inds = torch.cat(pos_inds)
            mask_feats = bbox_feats[pos_inds]

        if mask_feats.shape[0] > 0:
            mask_pred = self.mask_head(mask_feats)
            mask_targets = self.mask_head.get_target(
                sampling_results, gt_masks, self.train_cfg.rcnn)
            pos_labels = torch.cat(
                [res.pos_gt_labels for res in sampling_results])
            loss_mask = self.mask_head.loss(mask_pred, mask_targets,
                                            pos_labels)
            losses.update(loss_mask)

    return losses
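Back in the training loop, the loss dict returned by forward_train is reduced to a single scalar before backpropagation. Below is a minimal sketch of that reduction (my own simplified version, in the spirit of the parse_losses helper in mmdet's training API, not the verbatim code): every entry is averaged and all entries are summed.

import torch

def reduce_losses(losses):
    # Minimal sketch: collapse a dict of loss tensors into one scalar.
    total = 0
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            total = total + value.mean()
        elif isinstance(value, list):
            # some heads return one loss tensor per FPN level
            total = total + sum(v.mean() for v in value)
    return total

# usage:
# losses = model(img, img_meta, gt_bboxes=gt_bboxes, gt_labels=gt_labels)
# loss = reduce_losses(losses)
# loss.backward()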