mmdetection Code Reading Series (4): RepPoints Code Reading (Part 2), AnchorFreeHead

Head

1. Overview

The inheritance hierarchy of RepPointsHead is as follows:

(each class inherits from the classes indented beneath it)

RepPointsHead
  • AnchorFreeHead
      • BaseDenseHead
          • BaseModule
              • torch.nn.Module
      • BBoxTestMixin
          • object

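BaseModule was introduced earlier, so it is not covered again here.
As mentioned earlier, a Head has to implement three functions: forward_train, simple_test and aug_test.
The first two are declared in BaseDenseHead and aug_test is declared in AnchorFreeHead, but none of them does the actual work. simple_test and aug_test delegate to simple_test_bboxes and aug_test_bboxes, and default implementations for those two are only provided by BBoxTestMixin. BaseDenseHead splits forward_train into three functions, __call__ (forward), loss and get_bboxes. These only get concrete implementations in AnchorFreeHead, or even as late as RepPointsHead. The methods and call graphs of these three parent Head classes are listed first: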
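One way to see which class supplies each method is to rebuild the hierarchy with empty stand-in classes and inspect Python's method resolution order (MRO). The classes below are hypothetical stubs named after the mmdetection ones. Nothing is imported from mmdetection or torch, and only the inheritance shape follows the tree above.

```python
class Module:                       # stand-in for torch.nn.Module
    pass


class BaseModule(Module):           # stand-in for mmcv's BaseModule
    pass


class BBoxTestMixin:
    # the mixin supplies the *_bboxes test helpers
    def simple_test_bboxes(self, feats, img_metas, rescale=False):
        return "BBoxTestMixin.simple_test_bboxes"

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        return "BBoxTestMixin.aug_test_bboxes"


class BaseDenseHead(BaseModule):
    def simple_test(self, feats, img_metas, rescale=False):
        # declared here, but the work is done by a method found via the MRO
        return self.simple_test_bboxes(feats, img_metas, rescale=rescale)


class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    def aug_test(self, feats, img_metas, rescale=False):
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)


class RepPointsHead(AnchorFreeHead):
    pass


mro_names = [cls.__name__ for cls in RepPointsHead.__mro__]
print(mro_names)
print(RepPointsHead().simple_test(None, None))
```

BaseDenseHead and its bases come before BBoxTestMixin in the MRO, so a method defined on both sides would resolve to the BaseDenseHead branch. The mixin only fills in methods that branch does not define.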

BaseDenseHead call graph:

BaseDenseHead
  • forward_train
      • __call__ / forward
      • abstract: loss / get_bboxes
  • simple_test
      • simple_test_bboxes

BBoxTestMixin call graph:

BBoxTestMixin
  • simple_test_bboxes
      • forward
      • get_bboxes
  • aug_test_bboxes
      • get_bboxes
      • merge_aug_bboxes
          • bbox_mapping_back
      • multiclass_nms
  • simple_test_rpn / aug_test_rpn
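In the aug_test_bboxes branch, the boxes predicted on each augmented view are mapped back to the original image by bbox_mapping_back. Only then are they merged and passed to multiclass_nms. The sketch below shows that inverse mapping in plain Python. It assumes only a horizontal flip and a single scalar scale factor, and map_box_back is a hypothetical helper. mmdetection's own bbox_mapping_back works on tensors, stores a per-axis scale factor and supports more flip directions.

```python
def map_box_back(box, img_shape, scale_factor, flip):
    """Map one (x1, y1, x2, y2) box from an augmented view back to the original image.

    img_shape is the (h, w) of the augmented (resized) image, and
    scale_factor is the resize ratio that was applied to the original image.
    """
    x1, y1, x2, y2 = box
    if flip:
        # undo the horizontal flip first, in the augmented image's coordinates
        w = img_shape[1]
        x1, x2 = w - x2, w - x1
    # then undo the resize
    return tuple(v / scale_factor for v in (x1, y1, x2, y2))


# a box predicted on a 2x-upscaled, horizontally flipped 200x400 view
print(map_box_back((300.0, 20.0, 380.0, 100.0), (200, 400), 2.0, flip=True))
# (10.0, 10.0, 50.0, 50.0)
```

The flip must be undone before the rescale because the flip is defined relative to the width of the augmented image.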

AnchorFreeHead call graph:

AnchorFreeHead
  • __init__
      • _init_layers
          • _init_cls_convs
          • _init_reg_convs
          • _init_predictor
  • forward
      • forward_single
          • cls_layer (each layer in cls_convs)
          • conv_cls
          • reg_layer (each layer in reg_convs)
          • conv_reg
  • abstract: loss / get_bboxes / get_targets
  • get_points
  • aug_test
      • aug_test_bboxes
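get_points in the graph above builds a grid of (y, x) cell indices for every FPN level, one entry per feature-map cell. In section 4 this is done by _get_points_single with torch.meshgrid. The sketch below is a pure-Python equivalent. It also shows an FCOS-style conversion to image coordinates (index * stride + stride // 2) that a subclass may apply on top. Both helper names are hypothetical, and the conversion is an illustration rather than what every subclass does.

```python
def points_single(featmap_size, flatten=True):
    """Pure-Python analogue of torch.meshgrid(y_range, x_range) with 'ij' indexing."""
    h, w = featmap_size
    ys = [[i for _ in range(w)] for i in range(h)]  # row index repeated along each row
    xs = [[j for j in range(w)] for _ in range(h)]  # column index repeated down each column
    if flatten:
        ys = [v for row in ys for v in row]
        xs = [v for row in xs for v in row]
    return ys, xs


def to_image_coords(ys, xs, stride):
    """Hypothetical FCOS-style mapping: cell index -> cell centre in input pixels, as (x, y)."""
    return [(x * stride + stride // 2, y * stride + stride // 2) for y, x in zip(ys, xs)]


ys, xs = points_single((2, 3))
print(ys)  # [0, 0, 0, 1, 1, 1]
print(xs)  # [0, 1, 2, 0, 1, 2]
print(to_image_coords(ys, xs, stride=8)[:3])  # [(4, 4), (12, 4), (20, 4)]
```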

2. BaseDenseHead

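BaseDenseHead splits forward_train into three operations: __call__ (forward), loss and get_bboxes. None of them is implemented here. loss and get_bboxes are declared abstract, and forward is left to subclasses, so subclasses have to implement all three.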
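The abstract declarations are enforced by Python itself. A class built with metaclass=ABCMeta cannot be instantiated while any @abstractmethod is still unimplemented. The toy classes below are hypothetical and only mimic the two abstract methods of BaseDenseHead.

```python
from abc import ABCMeta, abstractmethod


class ToyBaseDenseHead(metaclass=ABCMeta):
    @abstractmethod
    def loss(self, *args, **kwargs):
        """Compute losses of the head."""

    @abstractmethod
    def get_bboxes(self, *args, **kwargs):
        """Transform raw outputs into bbox predictions."""


class HalfDoneHead(ToyBaseDenseHead):
    def loss(self, *args, **kwargs):   # get_bboxes is still abstract
        return {}


class ToyHead(HalfDoneHead):
    def get_bboxes(self, *args, **kwargs):
        return []


def can_instantiate(cls):
    """Return True if cls() succeeds, False if ABCMeta rejects it."""
    try:
        cls()
    except TypeError:
        return False
    return True


for cls in (ToyBaseDenseHead, HalfDoneHead, ToyHead):
    print(cls.__name__, can_instantiate(cls))
```

Only ToyHead, which implements both methods, can be instantiated. This is why AnchorFreeHead, which redeclares loss, get_bboxes and get_targets as abstract, is never built directly.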

from abc import ABCMeta, abstractmethod

from mmcv.runner import BaseModule


class BaseDenseHead(BaseModule, metaclass=ABCMeta):
    @abstractmethod
    def loss(self, **kwargs):
        """Compute losses of the head."""
        pass

    @abstractmethod
    def get_bboxes(self, **kwargs):
        """Transform network output for a batch into bbox predictions."""
        pass

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """
        Args:
            x (list[Tensor]): Features from FPN.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used

        Returns:
            tuple:
                losses: (dict[str, Tensor]): A dictionary of loss components.
                proposal_list (list[Tensor]): Proposals of each image.
        """
        outs = self(x)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
            return losses, proposal_list

    def simple_test(self, feats, img_metas, rescale=False):
        """Test function without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                The first item is ``bboxes`` with shape (n, 5),
                where 5 represent (tl_x, tl_y, br_x, br_y, score).
                The shape of the second tensor in the tuple is ``labels``
                with shape (n,)
        """
        return self.simple_test_bboxes(feats, img_metas, rescale=rescale)
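The tuple concatenation in forward_train determines which positional arguments loss receives. outs (for AnchorFreeHead, (cls_scores, bbox_preds)) is followed by the ground-truth tuple, and gt_labels is dropped when it is None, e.g. for a class-agnostic RPN-style head. The toy head below copies the forward_train body from above and stubs forward, loss and get_bboxes so that they echo their inputs. EchoHead and its strings are hypothetical.

```python
class EchoHead:
    """Toy head: forward_train copied from BaseDenseHead, everything else stubbed."""

    def __call__(self, x):
        # stands in for forward(): returns (cls_scores, bbox_preds)
        return ("cls_scores", "bbox_preds")

    def loss(self, *args, gt_bboxes_ignore=None):
        return {"loss_args": args, "ignore": gt_bboxes_ignore}

    def get_bboxes(self, *args, cfg=None):
        return ["proposals", args, cfg]

    def forward_train(self, x, img_metas, gt_bboxes, gt_labels=None,
                      gt_bboxes_ignore=None, proposal_cfg=None, **kwargs):
        outs = self(x)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
        return losses, proposal_list


head = EchoHead()
print(head.forward_train("x", "metas", "gts", gt_labels="labels")["loss_args"])
# ('cls_scores', 'bbox_preds', 'gts', 'labels', 'metas')
losses, proposals = head.forward_train("x", "metas", "gts", proposal_cfg="cfg")
print(losses["loss_args"])  # ('cls_scores', 'bbox_preds', 'gts', 'metas')
```

The positional order of loss's parameters in AnchorFreeHead, (cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas), therefore has to match this concatenation exactly.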

3. BBoxTestMixin

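simple_test (declared in BaseDenseHead) and aug_test (declared in AnchorFreeHead) are both implemented by calling simple_test_bboxes and aug_test_bboxes, so for one-stage models these two functions are all you need to care about. Their core work, forward and get_bboxes, is still delegated to AnchorFreeHead and RepPointsHead. BBoxTestMixin can therefore be viewed as just a generic shell around the test-time code.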
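A minimal sketch of that shell follows. It assumes simple_test_bboxes just runs forward and passes the outputs to get_bboxes, as the BBoxTestMixin call graph suggests. Check your mmdetection version for extra arguments and post-processing. ToyBBoxTestMixin and ToyDenseHead are hypothetical.

```python
class ToyBBoxTestMixin:
    def simple_test_bboxes(self, feats, img_metas, rescale=False):
        # the mixin owns only the control flow ...
        outs = self.forward(feats)
        # ... and hands the raw outputs to the subclass's decoder
        return self.get_bboxes(*outs, img_metas, rescale=rescale)


class ToyDenseHead(ToyBBoxTestMixin):
    def forward(self, feats):
        # one score map and one box map per level, like AnchorFreeHead.forward
        return [f * 10 for f in feats], [f + 1 for f in feats]

    def get_bboxes(self, cls_scores, bbox_preds, img_metas, rescale=False):
        # one result per image, built from all levels
        return [{"scores": cls_scores, "boxes": bbox_preds, "meta": meta, "rescale": rescale}
                for meta in img_metas]


print(ToyDenseHead().simple_test_bboxes([1, 2], ["img0"], rescale=True))
```

Because the mixin never touches tensors itself, the same shell serves FCOS, RepPoints and any other head that provides forward and get_bboxes.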


4. AnchorFreeHead

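AnchorFreeHead is the first Head in this hierarchy that actually contains computation layers (here, convolutions). It can be seen as playing two roles:

  • creating the computation layers, in _init_layers;
  • computing the outputs of those layers, in forward.

loss and get_bboxes (together with get_targets) are still left unimplemented at this level.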
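_init_cls_convs and _init_reg_convs build the same tower of stacked_convs 3x3 ConvModules. Only the first conv takes in_channels, and every later one maps feat_channels to feat_channels. dcn_on_last_conv swaps the last conv for DCNv2. The sketch below only lists that plan in plain Python and builds no layers. conv_tower_plan is a hypothetical name. The 'Conv2d' label for conv_cfg=None reflects mmcv's default conv type.

```python
def conv_tower_plan(in_channels, feat_channels=256, stacked_convs=4,
                    dcn_on_last_conv=False, conv_cfg=None):
    """Return (in_ch, out_ch, conv_type) per layer, mirroring the loop in _init_cls_convs."""
    plan = []
    for i in range(stacked_convs):
        chn = in_channels if i == 0 else feat_channels
        if dcn_on_last_conv and i == stacked_convs - 1:
            cfg = dict(type='DCNv2')
        else:
            cfg = conv_cfg
        # conv_cfg=None means ConvModule falls back to a plain Conv2d
        conv_type = cfg['type'] if cfg else 'Conv2d'
        plan.append((chn, feat_channels, conv_type))
    return plan


for layer in conv_tower_plan(in_channels=512, dcn_on_last_conv=True):
    print(layer)
# (512, 256, 'Conv2d')
# (256, 256, 'Conv2d')
# (256, 256, 'Conv2d')
# (256, 256, 'DCNv2')
```

After the tower, _init_predictor adds conv_cls (feat_channels to cls_out_channels) and conv_reg (feat_channels to 4), so each feature-map location yields class scores plus 4 box values.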
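from abc import abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
# relative imports: this file lives in mmdet/models/dense_heads/
from ..builder import HEADS
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()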
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.)."""

    _version = 1

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        ...
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        ...

    def forward(self, feats):
        """Forward features from the upstream network.
        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): num_points * num_classes.
                bbox_preds (list[Tensor]): num_points * 4.
        """
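        # multi_apply runs forward_single on every FPN level and regroups the
        # results by output; [:2] keeps (cls_scores, bbox_preds) and drops the features
        return multi_apply(self.forward_single, feats)[:2]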

    def forward_single(self, x):
        """Forward features of a single scale level."""
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        raise NotImplementedError

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=None):
        """Transform network output for a batch into bbox predictions."""
        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points in multiple images."""
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level."""
        h, w = featmap_size
        # First create the ranges with the default dtype, then convert them to
        # the target `dtype` (needed for ONNX export).
        x_range = torch.arange(w, device=device).to(dtype)
        y_range = torch.arange(h, device=device).to(dtype)
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes."""
        mlvl_points = []
        for i in range(len(featmap_sizes)):
            mlvl_points.append(
                self._get_points_single(featmap_sizes[i], self.strides[i],
                                        dtype, device, flatten))
        return mlvl_points

    def aug_test(self, feats, img_metas, rescale=False):
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
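forward relies on mmdetection's multi_apply helper. It runs forward_single on every FPN level and regroups the per-level result tuples into one list per output. The [:2] then keeps (cls_scores, bbox_preds) and discards cls_feat and reg_feat. The helper below re-implements that behaviour in plain Python (map, then zip-transpose), following mmdet.core's multi_apply. toy_forward_single is a hypothetical stand-in for forward_single.

```python
from functools import partial


def multi_apply(func, *args, **kwargs):
    """Apply func to each item of args, then regroup the results by output position."""
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)              # one result tuple per FPN level
    return tuple(map(list, zip(*map_results)))  # one list per output


def toy_forward_single(x):
    # stand-in for forward_single: (cls_score, bbox_pred, cls_feat, reg_feat)
    return f"cls@{x}", f"reg@{x}", f"cls_feat@{x}", f"reg_feat@{x}"


feats = ["P3", "P4", "P5"]
cls_scores, bbox_preds = multi_apply(toy_forward_single, feats)[:2]
print(cls_scores)  # ['cls@P3', 'cls@P4', 'cls@P5']
print(bbox_preds)  # ['reg@P3', 'reg@P4', 'reg@P5']
```

Subclasses such as RepPointsHead can override forward_single to return different outputs. forward itself does not need to change as long as the first two outputs stay the score and box maps.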