runner
是用于管理整个(训练+验证)、测试流程的类。
mmfewshot 中的runners
根据任务需求,大致划分为EpochBasedRunner
IterBasedRunner
,两者均继承自BaseRunner
。
1 BaseRunner
BaseRunner
实现了初始化方法,预定义了如下抽象方法,需要在子类中实现:
run()
train()
val()
save_checkpoint()
初始化方法接收 model
、optimizer
、logger
,并检验各个对象的类别,将他们保存为自己的类成员。model
可以 为wrapper
或模型本身。且model
必须具备train_setp
方法。
BaseRunner
实现了很多 hook
注册方法,比如register_hook(self)
。该方法接收一个hook
对象,并根据其 priorty
插入到self._hooks
中。
提供了很多方法,用以访问内部属性情况,比较重要的有get_hook_info(self)
,返回runner
在不同周期内部调用hook
的顺序。
2 EpochBasedRunner
继承自BaseRunner
,除开run()
train()
val()
save_checkpoint()
外,为满足 resume 的需求,也实现了相关方法。且实现了train_iter()
方法,该方法在train()
中循环调用,用以处理单个batch
。
2.1 EpochBasedRunner在框架中的调用
以train.py
为例。
- parse args;生成
cfg dict
类对象;构建任务路径;初始化model
;初始化train_dataset
,或有val_dataset
。 - 调用
train_model
APItrain_model( model, datasets, cfg, distributed=distributed, validate=(not args.no_validate), timestamp=timestamp, device='cpu' if args.device == 'cpu' else 'cuda', meta=meta)
within train_model
:
- 构建
train_dataloader
- 根据cfg,用
MMDataParallel
/MMDistributedDataParallel
wrapmoel
。 - 构建
optimizer
- 实例化
runner
runner = build_runner( cfg.runner, default_args=dict( model=model, batch_processor=None, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta))
- 注册
training hooks
,runner.register_training_hooks()
,该方法接收 cfg 字典 - 如果需要validate,build
val_dataset
。val_dataset
一般是MetaTestDataset(EpisodicDataset)
类,其内部的__getitem__
方法根据self._mode
有不同的返回值。self._mode
有3种设定值,test_set
、support
、query
。设置为test_set
,dataset
返回单个的数据样本,这个mode是用于遍历所有的测试样本,提取feature并保存,用以加速后续的测试环节。 build_meta_test_dataloader(val_dataset, meta_test_cfg)
构建validate dataloader。不管在pretraining
、meta training
阶段,validate
都是采用episodic
的方式,这与meta testing
的数据pipeline是相同的,不用过分纠结函数命名问题。meta test
就是指episodic validate
。注意,valdataset
的self._mode
有三种模式,比起构建一个dataloader,频繁的切换self._mode
,直接采用深度拷贝构建3中不同模式的dataloader在验证阶段使用更为方便。一般默认使用episodic
的方式进行validate,所以val_datasets
都是EpsodicDataset
的子类。- 实例化
MetaTestEvalHook
。并将其注册入runner._hooks
。MetaTestEvalHook
用于模型验证,在初始化阶段接收上述3个 dataloade,以及validate cfg。MetaTestEvalHook
实现了若干hook
方法,较为重要的是after_train_epoch
、after_train_iter
。此外,实现了evaluate
与_save_ckpt
,在model
进行一个epoch
/batch
的forward
后,会分别调用这两个hook
。根据 validate cfg 的差异,MetaTestEvalHook
会在after_train_epoch
、after_train_iter
方法中选择调用evaluate
方法,进行模型验证。一般而言,EpochBasedRunner
选择在after_train_epoch
中进行验证;iterBasedRunner
则在after_train_iter
中进行验证。如果验证模型具有更高的精度,则在evaluate
中调用_save_ckpt
进行保存。pretraining
阶段使用EpochBasedRunner
,meta training
阶段使用IterBasedRunner
。 - 注册
default hoos
与custom hooks
到runner
中。 runner.resume
/runner.load_checkpoint
重启训练或读取CKPTrunner.run(data_loaders, cfg.workflow)
开启训练。注意,若cfg.workflow
为('train',1)
,此处传入的data_loaders
只包含train_dataloader
。因为3个val_dataloader
已经保存在MetaTestEvalHook
的类成员中了,并且注册到runner._hooks
。
within runnner.run()
:
开始迭代data_loader
:
self.call_hook('before_run')
while self.epoch < self._max_epochs:
-
- 根据传入的
cfg.workflow
中的 str mode,来判断当前任务的类型(train/val),是执行runner.train
还是runner.val
:epoch_runner = getattr(self, mode)
。但一般 str mode都置为train
, validate在hook中进行。
- 根据传入的
-
- 调用
epoch_runner(data_loader)
,runner.train
或runner.val
。一般都只调用runnner.train
,所以下文只讨论前者。
- 调用
self.call_hook('after_run')
within runner.train():
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(self.data_loader):
-
self.call_hook('before_train_iter')
-
self.run_iter(data_batch, train_mode=True, **kwargs)
-
self.call_hook('after_train_iter')
, 在IterBasedRunner
中,MetaTestEvalHook
会于此调用evaluate()
self.call_hook('after_train_epoch')
, 在EpochBasedRunner
中,MetaTestEvalHook
会于此调用evaluate()
within runner.run_iter(data_batch, train_mode=True):
- 根据 bool
train_mode
判断执行self.model.train_step()
/self.model.val_step()
- 保存model 的输出。