👉👉 Motivation
When training models with MMEngine and setting up the various hooks, the underlying source code is rarely in front of you: you train the network following fixed recipes, and whenever you want to change something you end up experimenting with parameters by trial and error. So let's dig all the way to the bottom for once and see how the lowest-level code is written, so we never have to guess parameters again.
MMEngine supports two training modes:
- epoch-based (EpochBased)
- iteration-based (IterBased)
Both modes are used throughout the downstream algorithm libraries: for example, MMDetection uses the EpochBased mode by default, while MMSegmentation defaults to the IterBased mode. If you want to know how to switch between the two, this article is all you need.
🌟🌟 MMEngine.runner config example
from mmengine.runner import Runner
cfg = dict(
model=dict(type='ToyModel'),
work_dir='path/of/work_dir',
train_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=1,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=1,
num_workers=0),
test_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=1,
num_workers=0),
auto_scale_lr=dict(base_batch_size=16, enable=False),
    optim_wrapper=dict(type='OptimWrapper', optimizer=dict(
type='SGD', lr=0.01)),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=dict(type='ToyEvaluator'),
test_evaluator=dict(type='ToyEvaluator'),
train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
val_cfg=dict(),
test_cfg=dict(),
custom_hooks=[],
default_hooks=dict(
timer=dict(type='IterTimerHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
log_processor=dict(window_size=20),
visualizer=dict(type='Visualizer',
vis_backends=[dict(type='LocalVisBackend',
save_dir='temp_dir')])
)
runner = Runner.from_cfg(cfg)
runner.train()
runner.test()
Today we focus on the train_cfg setting. The official definition of train_cfg is:
train_cfg (dict, optional): A dict to build a training loop. If it does not
provide "type" key, it should contain "by_epoch" to decide which type of training
loop :class:`EpochBasedTrainLoop` or :class:`IterBasedTrainLoop` should be used.
If ``train_cfg`` specified, :attr:`train_dataloader` should also be specified.
Defaults to None. See :meth:`build_train_loop` for more details.
As you can see, train_cfg selects between two loop classes (a sketch of both config forms follows this list):
- `EpochBasedTrainLoop`
- `IterBasedTrainLoop`
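Based on the docstring above, there are two equivalent ways to tell the Runner which loop to run: pass an explicit `type`, or omit `type` and let `by_epoch` decide. A minimal sketch (the numbers are placeholders):
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=3, val_interval=1)   # explicit loop type
train_cfg = dict(by_epoch=True, max_epochs=3, val_interval=1)                # -> EpochBasedTrainLoop
train_cfg = dict(by_epoch=False, max_iters=10000, val_interval=1000)         # -> IterBasedTrainLoop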
🎵🎵 MMEngine.runner source code
class Runner:
cfg: Config
_train_loop: Optional[Union[BaseLoop, Dict]]
_val_loop: Optional[Union[BaseLoop, Dict]]
_test_loop: Optional[Union[BaseLoop, Dict]]
def __init__(
self,
model: Union[nn.Module, Dict],
work_dir: str,
train_dataloader: Optional[Union[DataLoader, Dict]] = None,
val_dataloader: Optional[Union[DataLoader, Dict]] = None,
test_dataloader: Optional[Union[DataLoader, Dict]] = None,
train_cfg: Optional[Dict] = None,
val_cfg: Optional[Dict] = None,
test_cfg: Optional[Dict] = None,
auto_scale_lr: Optional[Dict] = None,
optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
data_preprocessor: Union[nn.Module, Dict, None] = None,
load_from: Optional[str] = None,
resume: bool = False,
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
log_processor: Optional[Dict] = None,
log_level: str = 'INFO',
visualizer: Optional[Union[Visualizer, Dict]] = None,
default_scope: str = 'mmengine',
randomness: Dict = dict(seed=None),
experiment_name: Optional[str] = None,
cfg: Optional[ConfigType] = None,
):
        self._work_dir = osp.abspath(work_dir)
......
    @classmethod
    def from_cfg(cls, cfg: ConfigType) -> 'Runner':
"""Build a runner from config.
Args:
cfg (ConfigType): A config used for building runner. Keys of
``cfg`` can see :meth:`__init__`.
Returns:
Runner: A runner build from ``cfg``.
"""
cfg = copy.deepcopy(cfg)
runner = cls(
model=cfg['model'],
work_dir=cfg['work_dir'],
train_dataloader=cfg.get('train_dataloader'),
val_dataloader=cfg.get('val_dataloader'),
test_dataloader=cfg.get('test_dataloader'),
train_cfg=cfg.get('train_cfg'),
val_cfg=cfg.get('val_cfg'),
test_cfg=cfg.get('test_cfg'),
auto_scale_lr=cfg.get('auto_scale_lr'),
optim_wrapper=cfg.get('optim_wrapper'),
param_scheduler=cfg.get('param_scheduler'),
val_evaluator=cfg.get('val_evaluator'),
test_evaluator=cfg.get('test_evaluator'),
default_hooks=cfg.get('default_hooks'),
custom_hooks=cfg.get('custom_hooks'),
data_preprocessor=cfg.get('data_preprocessor'),
load_from=cfg.get('load_from'),
resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))),
log_processor=cfg.get('log_processor'),
log_level=cfg.get('log_level', 'INFO'),
visualizer=cfg.get('visualizer'),
default_scope=cfg.get('default_scope', 'mmengine'),
randomness=cfg.get('randomness', dict(seed=None)),
experiment_name=cfg.get('experiment_name'),
cfg=cfg,
)
return runner
......
......
@property
def train_loop(self):
""":obj:`BaseLoop`: A loop to run training."""
if isinstance(self._train_loop, BaseLoop) or self._train_loop is None:
return self._train_loop
else:
self._train_loop = self.build_train_loop(self._train_loop)
return self._train_loop
As you can see, the train_loop property is built around BaseLoop: if the stored config is not already a BaseLoop instance, it is passed to build_train_loop. BaseLoop has two concrete training subclasses, IterBasedTrainLoop and EpochBasedTrainLoop, which correspond to the type we pass in the config, so those are the classes we need to look at.
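For orientation, the dispatch inside build_train_loop works roughly as follows. This is a simplified sketch, not the verbatim MMEngine source: an explicit type key is built from the LOOPS registry, otherwise by_epoch picks one of the two loop classes.
# simplified sketch of Runner.build_train_loop (key validation and details omitted)
def build_train_loop(self, loop):
    if isinstance(loop, BaseLoop):   # already a constructed loop
        return loop
    loop_cfg = copy.deepcopy(loop)
    if 'type' in loop_cfg:
        # e.g. dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
        return LOOPS.build(
            loop_cfg,
            default_args=dict(runner=self, dataloader=self._train_dataloader))
    # no 'type' key: fall back to the ``by_epoch`` switch
    by_epoch = loop_cfg.pop('by_epoch')
    loop_cls = EpochBasedTrainLoop if by_epoch else IterBasedTrainLoop
    return loop_cls(runner=self, dataloader=self._train_dataloader, **loop_cfg)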
🙆🙆 IterBasedTrainLoop
🌸 Inputs
- runner (Runner) – A reference of the runner.
- dataloader (Dataloader or dict) – A dataloader object or a dict used to build a dataloader.
- max_iters (int) – Total training iterations.
- val_begin (int) – The iteration at which validation begins. Defaults to 1.
- val_interval (int) – Validation interval. Defaults to 1000.
- dynamic_intervals (List[Tuple[int, int]], optional) – The first element of each tuple is a milestone and the second is an interval; the interval is used after the corresponding milestone. Defaults to None. A config sketch follows below.
🌸 Outputs
- None
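Putting the parameters together, an iteration-based train_cfg with a dynamic validation interval might look like the sketch below; the milestone and interval values are placeholders for illustration, not recommendations:
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=80000,                     # stop after 80k iterations
    val_interval=4000,                   # validate every 4k iterations at first
    dynamic_intervals=[(60000, 2000)])   # after the 60k-iteration milestone, validate every 2k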
🌸 IterBasedTrainLoop source code
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
"""Loop for iter-based training.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_iters (int): Total training iterations.
val_begin (int): The iteration that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is a interval. The interval is used after the
corresponding milestone. Defaults to None.
"""
def __init__(
self,
runner,
dataloader: Union[DataLoader, Dict],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_iters = int(max_iters)
assert self._max_iters == max_iters, \
f'`max_iters` should be a integer number, but get {max_iters}'
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
self._epoch = 0
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
# This attribute will be updated by `EarlyStoppingHook`
# when it is enabled.
self.stop_training = False
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.',
logger='current',
level=logging.WARNING)
# get the iterator of the dataloader
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
return self._max_epochs
@property
def max_iters(self):
"""int: Total iterations to train model."""
return self._max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
def run(self) -> None:
"""Launch training."""
self.runner.call_hook('before_train')
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
while self._iter < self._max_iters and not self.stop_training:
self.runner.model.train()
data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch)
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._iter >= self.val_begin
and self._iter % self.val_interval == 0):
self.runner.val_loop.run()
self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')
return self.runner.model
def run_iter(self, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process.
# outputs should be a dict of loss.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.call_hook(
'after_train_iter',
batch_idx=self._iter,
data_batch=data_batch,
outputs=outputs)
self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
self.val_interval = self.dynamic_intervals[step - 1]
🙆🙆 EpochBasedTrainLoop
🌸 Inputs
- runner (Runner) – A reference of the runner.
- dataloader (Dataloader or dict) – A dataloader object or a dict used to build a dataloader.
- max_epochs (int) – Total training epochs.
- val_begin (int) – The epoch at which validation begins. Defaults to 1.
- val_interval (int) – Validation interval. Defaults to 1.
- dynamic_intervals (List[Tuple[int, int]], optional) – The first element of each tuple is a milestone and the second is an interval; the interval is used after the corresponding milestone. Defaults to None. See the config sketch after this list.
🌸 Outputs
- None
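As a sketch, an epoch-based train_cfg that skips early validation could look like this; the values are placeholders chosen to illustrate val_begin:
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=300,     # stop after 300 epochs
    val_begin=50,       # do not run validation before epoch 50
    val_interval=10)    # then validate every 10 epochs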
🌸 EpochBasedTrainLoop source code
@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
"""Loop for epoch-based training.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_epochs (int): Total training epochs.
val_begin (int): The epoch that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is a interval. The interval is used after the
corresponding milestone. Defaults to None.
"""
def __init__(
self,
runner,
dataloader: Union[DataLoader, Dict],
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_epochs = int(max_epochs)
assert self._max_epochs == max_epochs, \
f'`max_epochs` should be a integer number, but get {max_epochs}.'
self._max_iters = self._max_epochs * len(self.dataloader)
self._epoch = 0
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
# This attribute will be updated by `EarlyStoppingHook`
# when it is enabled.
self.stop_training = False
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
else:
print_log(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.',
logger='current',
level=logging.WARNING)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
return self._max_epochs
@property
def max_iters(self):
"""int: Total iterations to train model."""
return self._max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
def run(self) -> torch.nn.Module:
"""Launch training."""
self.runner.call_hook('before_train')
while self._epoch < self._max_epochs and not self.stop_training:
self.run_epoch()
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0):
self.runner.val_loop.run()
self.runner.call_hook('after_train')
return self.runner.model
def run_epoch(self) -> None:
"""Iterate one epoch."""
self.runner.call_hook('before_train_epoch')
self.runner.model.train()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch')
self._epoch += 1
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one min-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_train_iter', batch_idx=idx, data_batch=data_batch)
# Enable gradient accumulation mode and avoid unnecessary gradient
# synchronization during gradient accumulation process.
# outputs should be a dict of loss.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.call_hook(
'after_train_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
self.val_interval = self.dynamic_intervals[step - 1]
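To make _decide_current_val_interval concrete, here is a small self-contained sketch of the dynamic-interval lookup. It assumes that calc_dynamic_intervals prepends milestone 0 with the initial val_interval, which is what the bisect lookup above relies on; the numbers are made up for illustration:
import bisect

# assumed result of calc_dynamic_intervals(1, [(100, 5)]):
# milestone 0 -> interval 1, milestone 100 -> interval 5
dynamic_milestones = [0, 100]
dynamic_intervals = [1, 5]

def current_val_interval(epoch):
    # same lookup as _decide_current_val_interval: find the last milestone
    # already reached and return the interval attached to it
    step = bisect.bisect(dynamic_milestones, epoch + 1)
    return dynamic_intervals[step - 1]

print(current_val_interval(98))    # -> 1
print(current_val_interval(99))    # -> 5 (the 100-epoch milestone is reached)
print(current_val_interval(150))   # -> 5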
🔥🔥 Summary
train_cfg in the config file supports exactly two training modes: one based on the number of iterations, the other based on the number of epochs. The two ways to set it are shown below, followed by a sketch of what else typically changes when you switch between them.
👍 Iteration-based training
❤️ config
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
❤️ Parameters
- type: the training loop type.
- max_iters: the maximum number of training iterations; training stops after 80000 iterations.
- val_interval: the validation interval in iterations; validation runs every 4000 iterations.
👍 Epoch-based training
❤️ config
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200, val_interval=1)
❤️ Parameters
- type: the training loop type.
- max_epochs: the maximum number of training epochs; training stops after 200 epochs.
- val_interval: the validation interval in epochs; validation runs every epoch.
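When moving a downstream config (for example MMDetection's epoch-based default) over to iteration-based training, train_cfg is usually not the only thing that changes: schedulers, checkpointing, and logging each have their own by_epoch switch. The following is a rough sketch of the pieces that commonly move together; the exact keys and values depend on the config you start from:
# switch the training loop itself
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
# let schedulers step per iteration instead of per epoch
param_scheduler = dict(type='MultiStepLR', by_epoch=False, milestones=[60000, 72000])
# save checkpoints by iteration count as well
default_hooks = dict(checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
# report iteration-based progress in the logs
log_processor = dict(by_epoch=False)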
✌️✌️ Takeaways
The conclusion itself is simple: there are just two training modes and two corresponding parameter recipes. But the training-loop source code along the way is elegantly designed and well worth reading in its own right.
Full MMEngine source code: link
Putting this together took some effort; if it helped, a like, favorite, and follow are much appreciated!