前言
mmSegmentation 中的验证流程不满足需求, 因此自定义验证流程。
验证流程在 config 文件中配置如下, 属于 schedule 配置。
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True) # 每 16000 个迭代周期进行一次评估,评估指标为 mIoU
每当 interval 个 iter 训练完成后, 启用一次 evaluation。
版本信息
mmcv 1.4.8
mmsegmentation 0.23.0
自定义验证钩子(EvalHooks)
1 复制mmseg/core/evaluation/eval_hooks.py
为my_eval_hooks.py
2 在mmseg/core/evaluation/__init__.py
中将eval_hooks.py
中的EvalHook
替换为my_eval_hooks.py
中的EvalHook
3 可以啦, 在my_eval_hooks.py
的EvalHook._do_evaluate()
中修改验证流程吧。
注: EvalHook并不支持使用用配置文件注入, 因此只能用这种方式进行修改。
验证钩子生命周期
注册位置: mmseg/apis/train.py:train_segmentor()
中
# register eval hooks
if validate:
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=1,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
# In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
# priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
调用位置:
EvalHooks属于after_train_hook
, 在每个训练循环后被调用。
IterBasedRunner.train()
(mmcv.runner.iter_based_runner
) 控制一个Iter的主要流程, 当其执行self.call_hook('after_train_iter')
时会调用EvalHooks的after_train_iter()
函数。
after_train_iter()
由mmcv.runner.EvalHook实现, 我们的EvalHook继承自该类, after_train_iter()
会调用self._do_evaluate()
, 因此我们将在self._do_evaluate()
中修改流程。
def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
if not self._should_evaluate(runner):
return
from mmseg.apis import single_gpu_test
results = single_gpu_test(
runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
runner.log_buffer.clear()
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
self._save_ckpt(runner, key_score)
钩子上下文
钩子统一传入runner
运行时, 主要包括:
runner.model 模型
runner.data_loader 是训练集的dataloader
self.dataloader 才是验证集的dataloader
self.dataloader.dataset 可以直接访问数据集类, 功能可以写在这里
runner.meta 包括很多运行配置如'env_info', 'seed', 'exp_name', 'mmseg_version', 'config', 'CLASSES', 'PALETTE'