【代码】mmdetection框架

0.前言

这篇文章是使用mmdetection的一些记录,记录对于代码、设计理念的个人理解。

1.train

使用tools.train进行训练。添加如下代码来使用debug模式:

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    args = ['./configs/cascade_mask_rcnn_r101_fpn_1x.py',
            '--gpus', '1',
            '--work_dir', 'cascade_mask_rcnn_r101_fpn_1x'
            ]

1.1. 首先是建立模型:

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

其中,cfg是config文件,DETECTORS为全局对象,在models/registry中创建,是一个Registry对象。Registry类含_name和_module_dict属性,在一开始只将_name赋予’detector’等字符。在每个与检测器有关的类之前都有 @DETECTORS.register_module 修饰器,它可以将这个类以及其名字(_name_属性)在DETECTORS的_module_dict中。
build调用build_from_cfg,首先取出cfg建立的对象类型obj_type,使用get从注册器(Registry对象)中取出相应的类,使用inspect来判断取出的obj_type是否是类。之后使用obj_type(类)将args(就是cfg)作为参数进行实例化。

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')	# 对象的名字,比如CascadeRCNN
    if mmcv.is_str(obj_type):
        obj_type = registry.get(obj_type)
        if obj_type is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif not inspect.isclass(obj_type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)   #返回实例化

1.2. 建立train_dataset:

同样的套路,build调用build_from_cfg,按照cfg中的描述进行实例化,只是cfg是dataset的cfg。

1.3. 训练:

    train_detector(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

train_detector调用non_dist_train,在这里将model并行化,建立data_loaderoptimizerrunner

1.3.1 建立data_loader:
##################_non_dist_train部分
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]

##################build_dataloader函数
def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    shuffle = kwargs.get('shuffle', True)
    if dist:
        rank, world_size = get_dist_info()
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        **kwargs)

build_dataloader中主要是创建了两个对象samplercollate(通过偏函数partial来创建),前者是采样器,采样出下标,后者是整理器,用于组成batch输出。之后使用pytorch自带的DataLoader就行了。sampler考虑了并行操作。collate除了支持对于Sequence,Mapping的batch构建外,更重要的是有对于DataContainer类型数据的batch操作,这是一个mmdet中创建的新类型,支持多种数据类型。

1.3.3 建立runner:
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)
############batch_processor的定义
def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs

其中batch_processor调用model来得到loss(model的forward得到的是loss而不是网络的输出)。之后对loss进行一些小处理。
runner的初始化基本上就是model, optimizer, work_dir等的初始化。

1.3.4 runner运行:
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
# 删除部分
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str): 
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

workflow代表的是工作流程E.g, [(‘train’, 2), (‘val’, 1)] ;run中通过getattr获得epoch_runner,一般就是runner.trainrunner.val,前者就是一般的train过程,首先self.model.train()来避免eval状态。之后就是一般的train了,里面有用到多处的call_hook

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

call_hook的作用就是一个一个的hook的fn_name这个函数作用到自身,获得一些或者改变一些信息吧。

2. 思路

runner控制模型的训练、验证和测试过程。
dataloader负责数据的导入。
模型中anchor生成、anchor匹配等操作均隐藏在了model中,model又分为
与anchor有关的head:anchor_head
主干:backbones
ROI有关的head:bbox_heads
检测器本体:detectors
损失函数:losses
与mask有关的head:mask_heads
backbone进一步基础上的特征提取module:necks
attention机制等插件:plugins
roi提取器:roiextractors
不知道是啥:shared_heads
所有与技术细节有关的部分都放在了这些model当中。这些model也会调用core中的函数。

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值