mmdetection源码详细解读

简介

1. 测试代码

import mmcv
from mmdet.apis import init_detector, inference_detector, show_result_pyplot

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:1')

# test a single image and show the results
img = 'demo.jpg'  # or img = mmcv.imread(img), which will only load it once
result = inference_detector(model, img)

# visualize the results in a new window
show_result_pyplot(model, img, result)

新建一个工程,其目录结构为:
在这里插入图片描述
实际上,我把mmdetection看成是第三方库(类比numpy、cv2等),所以我把下载下来的GitHub代码解压到了python的工程文件中,即
在这里插入图片描述
但我看有的人直接把mmdetection当做是工程文件也是没问题的,比如D2Det的官方代码就是如此:你可以发现其目录结构和mmdetection的源码结构是一样的。
在这里插入图片描述

2. mmdetection/mmdet/apis

该文件夹下有三个文件:

  • inference.py:用于初始化模型、前向推理、读取图片、显示检测结果等。
  • train:用于训练。
  • test:用于测试。
    在这里插入图片描述

2.1 mmdetection/mmdet/apis/inference.py

2.1.1 init_detector

def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
    	# 如果config是配置文件路径,则用mmcv.Config.fromfile()读取,
    	# 并返回类mmcv.Config()实例化后的object
    	# 如果config已经是类mmcv.Config()实例化后的object,则不需要其他操作。
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model

函数作用、输入参数、输出参数直接见注解。这里给出几个例子:

from mmdet.apis import init_detector, inference_detector, show_result_pyplot

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# device 可以是'cpu'、'cuda:0'、'cuda:1'等。

2.1.2 inference_detector

def inference_detector(model, img):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        If imgs is a str, a generator will be returned, otherwise return the
        detection results directly.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # Use torchvision ops for CPU mode instead
        for m in model.modules():
            if isinstance(m, (RoIPool, RoIAlign)):
                if not m.aligned:
                    # aligned=False is not implemented on CPU
                    # set use_torchvision on-the-fly
                    m.use_torchvision = True
        warnings.warn('We set use_torchvision=True in CPU mode.')
        # just get the actual data from DataContainer
        data['img_metas'] = data['img_metas'][0].data

    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result

3. build_detector()

model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

  • 根据配置文件构建神经网络

以configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py配置文件为例,cfg.modelcfg.train_cfgcfg.test_cfg均为字典类型,分别与配置文件中的内容相对应。

DETECTORS = Registry('detector')
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
  • Registry是一个大工厂,工厂中包含了很多小仓库,共有7个小仓库,这些仓库在train.py开始import 模块时就自动创建,并且每个仓库中都放了各种样的小商品,比如在detectors/__init__.py中可以看到,仓库DETECTORS的小商品有:。
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
# detectors/__init__.py
from .atss import ATSS
from .base import BaseDetector
from .cascade_rcnn import CascadeRCNN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN
from .fcos import FCOS
from .fovea import FOVEA
from .fsaf import FSAF
from .gfl import GFL
from .grid_rcnn import GridRCNN
from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .nasfcos import NASFCOS
from .point_rend import PointRend
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
from .single_stage import SingleStageDetector
from .two_stage import TwoStageDetector
  • 此时参数cfg是字典类型,执行build_from_cfg()
def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
  • 比如在fcos.py文件中,修饰器@DETECTORS.register_module()的作用就是将创建好的小商品FCOS放入仓库DETECTORS中。
@DETECTORS.register_module()
class FCOS(SingleStageDetector):
    """Implementation of `FCOS <https://arxiv.org/abs/1904.01355>`_"""

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
                                   test_cfg, pretrained)
  • register_module()函数中,你可以看到它的用法,小商品的名字就是在注册时指定的,而注册时的名字就是类名FCOSResNet等。所以你要创建自己的检测器,你需要构建一个类,把这个类当做小商品,然后将其放在相应的仓库中(注册)
    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
  • obj_type='FCOS',在
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict."""
    args = cfg.copy()
    obj_type = args.pop('type') # 'FCOS'
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
mmdetection是一个基于PyTorch开发的目标检测框架,它提供了一系列模型和工具,用于实现目标检测任务。该框架的源码解读可以从其网络结构设计入手。 在mmdetection中,网络结构的设计是通过继承SingleStageDetector和TwoStageDetector来实现的。SingleStageDetector继承了backbone、neck和head,而TwoStageDetector继承了backbone、neck、rpn_head和roi_head。这种继承关系的设计使得模型的构建更加灵活和可扩展。 另外,mmdetection框架还将网络结构的设计细分为detectors、backbones、necks、dense_heads、roi_heads和seg_heads等组件。这些组件都有自己的base定义接口,并扩展了不同经典论文的结构,可以直接使用。这样的设计使得用户可以根据具体任务的需求选择合适的组件进行组合,从而实现更加精确和高效的目标检测。 总之,mmdetection框架的源码解读主要涉及网络结构的设计和组件的使用。通过深入理解这些内容,我们可以更好地理解和使用mmdetection框架来进行目标检测任务。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [mmdetection源码解析](https://blog.csdn.net/Cxiazaiyu/article/details/123995333)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值