目录
一、train_detector函数分析
train_detector函数定义
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
'type']
下面是train_detector的主干,我删除了异常判断、版本兼容、分布式训练等内容,下面列出来的是我认为比较重要的部分。
1、DataLoader
是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,一般只要是用PyTorch来训练模型基本都会用到该接口,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# `num_gpus` will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed,
runner_type=runner_type,
persistent_workers=cfg.data.get('persistent_workers', False))
for ds in dataset
]
2、optimizer
#构建优化器,optimizer目的:优化SGD,训练快速收敛并且保证准确率
optimizer = build_optimizer(model, cfg.optimizer)
3、runner
build runner runner(实现在mmcv中)主要是用来管理模型训练时的生命周期,负责 OpenMMLab 中所有框架的训练过程调度,也就是管理何时执行resume、logger、save checkpoint、学习率更新、梯度计算BP等常见操作。
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
4、register hooks
register hooks 注册多个hook,在训练过程中调用,学习率设置、优化器设置、模型保存、日志打印等。
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
5、runner.run
加载模型 # runner.run-> runner.train-> runner.run_iter->self.model.train_step,进行模型训练
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow)
二、train_detector参数说明
datasets:datasets = [build_dataset(cfg.data.train)]
model:model = build_detector(
cfg.model, train_cfg=cfg.get(‘train_cfg’), test_cfg=cfg.get(‘test_cfg’))
cfg:cfg = Config.fromfile(…)按照配置文件实例化得到
distributed:是否进行分布式训练
validate:训练过程中是否评估
train_detector(model, datasets, cfg, distributed=False, validate=True)