(十一)mmdetection源码解读:train_detector

一、train_detector函数分析

train_detector函数定义

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

下面是train_detector的主干,我删除了异常判断、版本兼容、分布式训练等内容,下面列出来的是我认为比较重要的部分。

1、DataLoader

是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,一般只要是用PyTorch来训练模型基本都会用到该接口,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type,
            persistent_workers=cfg.data.get('persistent_workers', False))
        for ds in dataset
    ]

2、optimizer

#构建优化器,optimizer目的:优化SGD,训练快速收敛并且保证准确率

   optimizer = build_optimizer(model, cfg.optimizer)

3、runner

build runner runner(实现在mmcv中)主要是用来管理模型训练时的生命周期,负责 OpenMMLab 中所有框架的训练过程调度,也就是管理何时执行resume、logger、save checkpoint、学习率更新、梯度计算BP等常见操作。

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

4、register hooks

register hooks 注册多个hook,在训练过程中调用,学习率设置、优化器设置、模型保存、日志打印等。

    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

5、runner.run

加载模型 # runner.run-> runner.train-> runner.run_iter->self.model.train_step,进行模型训练

    runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)

二、train_detector参数说明

datasets:datasets = [build_dataset(cfg.data.train)]
model:model = build_detector(
cfg.model, train_cfg=cfg.get(‘train_cfg’), test_cfg=cfg.get(‘test_cfg’))
cfg:cfg = Config.fromfile(…)按照配置文件实例化得到
distributed:是否进行分布式训练
validate:训练过程中是否评估

train_detector(model, datasets, cfg, distributed=False, validate=True)
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值