目标检测00-09:mmdetection(Foveabox为例)-源码无死角解析(2)-模型构建总览

以下链接是个人关于mmdetection(Foveabox-目标检测框架)所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号 海量资源。 \color{blue}{ 海量资源}。 海量资源
目标检测00-00:mmdetection(Foveabox为例)-目录-史上最新无死角讲解

前言

    从前面的博客,我们已经知道了数据读取,数据增强,以及训练架构等等。那么接下来,久要深入的了解 Foveabox 这个网络了。那么这篇博客。我们就来讲解一下其训练的流程吧。主要的相关代码位于 mmdet/models 文件夹(后续默认都以该文件夹为主-如果没有特别提示)。在了解模型是如何构件之前,我们先查看项目根目录 configs\foveabox\my_fovea_r50_fpn_4x4_2x_coco.py,该文件为本人自己编写,可以通过如下链接复制:目标检测00-04:mmdetection(Foveabox为例)-config文件注释-持续更新。找到其中的 model 字典,内容如下:

# model settings
model = dict(
    type='FOVEA', # 设置为FOVEA,则其最终会调用到类mmdet.models.detectors.fovea.FOVEA
	backbone=dict( # 主干网络相关配置
		type='ResNet', # 主干网络的类型
		......)
	neck=dict(
		type='FPN', # FPN,特征金字塔
		......)
    bbox_head=dict(
    	type='FoveaHead', # 头部网络类型
    	......
    	loss_cls=dict( # FocalLoss 的相关配置
			type='FocalLoss'......)
		loss_bbox=dict(
		    type='SmoothL1Loss', 
			.......)
)

总的来说,一个网络的构件,主要包含了三个部分,分别为 backbone(主干网络),neck(衔接网络),bbox_head(头部网络)。下面我们就来看看 Foveabox 究竟是如何构件的吧。

模型总体构建

根据上面的 model 字典结构,其首先执行的是 type=‘FOVEA’,对应 mmdet/models/detectors/fovea.py 中的 FOVEA。其代码十分的简单,就是继承于 SingleStageDetector(单阶段目标检测),然后调用其父类的初始化函数,那么我们就来看看 SingleStageDetector 的流程吧。本人注释代码如下:

@DETECTORS.register_module()
class SingleStageDetector(BaseDetector):
    """Base class for single-stage detectors.

    Single-stage detectors directly and densely predict bounding boxes on the
    output features of the backbone+neck.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageDetector, self).__init__()
        # 根据主干网网络参数构建主干网络
        self.backbone = build_backbone(backbone)
        # 如果设置了衔接网络,则构件衔接网络
        if neck is not None:
            self.neck = build_neck(neck)
        # 对头部网络的参数进行更新(训练和推理的配置参数有些不一样)
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        # 根据配置参数构件头部网络
        self.bbox_head = build_head(bbox_head)
        # 赋值训练以及测试参数
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # 如果设置了预训练模型,则加载预训练模型。
        self.init_weights(pretrained=pretrained)


    def init_weights(self, pretrained=None):
        """Initialize the weights in detector.
            权重初始化
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        super(SingleStageDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        self.bbox_head.init_weights()

    def extract_feat(self, img):
        """Directly extract features from the backbone+neck.
           使用主干网络提取特征,如果存在衔接网络,则同时采用衔接网络
        """
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.
        See `mmdetection/tools/get_flops.py`
        用于计算网络的浮点型大小,通过 mmdetection/tools/get_flops.py 可以看到具体介绍
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # 提取图像的金字塔特征,如果输入图像大小为[640,480],batch_size=b,那么其输出为存在五个元素的一个列表,
        # x = [(b,256,60,80), (b,256,30,40), (b,256,15,20), (b,256,8,10), (b,256,4,5)]
        x = self.extract_feat(img)
        # 如果是进行训练,则调用头部网络的forward_train,获得loss返回,同于反向传播
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.
        单张图片进行测试,测试功能不带有数据增强
        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
        Returns:
            np.ndarray: proposals
        """
        # 提出图像金字塔特征
        x = self.extract_feat(img)
        # 通过头部网络获得论文中的(tx1, ty1, tx2, ty2),以及其对应的类别概率, 但是大家要注意,这里出来是特征金字塔,如下(假设输入图像大小为640x480):
        # box的偏移值: outs[1]=[(1,4,60,80), (1,4,30,40), (1,4,15,20), (1,4,8,10), (1,4,4,5)]
        # box对应的类别概率 outs[0]=[(1,num_class,60,80), (1,num_class,30,40), (1,num_class,15,20), (1,num_class,8,10), (1,num_class,4,5)]
        outs = self.bbox_head(x)

        # 把(tx1, ty1, tx2, ty2) 转换为对应的 box 坐标,同时进行了 nms 处理。
        # img_metas 主要记录了输入图像的路径, 原始大小,当前大小,缩放因子,正则化参数。
        # 返回的 bbox_list 包含了两个元素,第一个元素存储所有 box 坐标以及概率值,第二元素存储 box 对应的类别。
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        # skip post-processing when exporting to ONNX,如果导出为 ONNX,则跳过后期处理
        if torch.onnx.is_in_onnx_export():
            return bbox_list

        # 把结果转化成numpy数组的形式。
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results[0]

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test function with test time augmentation."""
        raise NotImplementedError

代码领读

其实代码还是十分简单的,主要核心部分为 _init_ 函数的如下部分:

        # 根据主干网网络参数构建主干网络
        self.backbone = build_backbone(backbone)
        # 如果设置了衔接网络,则构件衔接网络
        if neck is not None:
            self.neck = build_neck(neck)
        # 根据配置参数构件头部网络
        self.bbox_head = build_head(bbox_head)

其构件的过程,跟前面配置文件中的 model 字典是一致的,如下:

# model settings
model = dict(
    type='FOVEA', # 设置为FOVEA,则其最终会调用到类mmdet.models.detectors.fovea.FOVEA
	backbone=dict( # 主干网络相关配置
		type='ResNet', # 主干网络的类型
		......)
	neck=dict(
		type='FPN', # FPN,特征金字塔
		......)
    bbox_head=dict(
    	type='FoveaHead', # 头部网络类型
    	......)
)    	

ResNet 位于代码 mmdet\models\backbones\resnet.py 之中,FPN 位于 mmdet\models\necks\fpn.py 之中。经过 self.neck 网络输出的最终结果为一个列表,形状如下如下(假设网络输入图片大小为640x480):

 [(b,256,60,80), (b,256,30,40), (b,256,15,20), (b,256,8,10), (b,256,4,5)]

拿到这个结果之后,其会送入到 self.bbox_head 进行出来。这里就是我们下篇博客的重点了,同时也是这片论文的重点。对于 self.backbone 以及 self.neck 的具体过程,本人就不进行讲解了,因为这些并不是 Foveabox 这片论文的重点。有兴趣的朋友可以自己分析一下源码。

在这里插入图片描述

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江南才尽,年少无知!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值