7. 模型包装
wrap_model
方法用于包装模型,以支持分布式训练。如果在分布式环境中,它会将模型包装为MMDistributedDataParallel
或其他自定义的分布式数据并行模块包装器。
def wrap_model(
self, model_wrapper_cfg: Optional[Dict],
model: nn.Module) -> Union[DistributedDataParallel, nn.Module]:
if is_model_wrapper(model):
if model_wrapper_cfg is not None:
raise TypeError(
'model has been wrapped and "model_wrapper_cfg" should be None, but got {}'.format(model_wrapper_cfg)
)
return model
# 将模型移动到指定设备
model = model.to(get_device())
if not self.distributed:
self.logger.info(
'Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.'
)
model = revert_sync_batchnorm(model)
return model
else:
sync_bn = self.cfg.get('sync_bn', None)
if sync_bn is not None:
try:
model = convert_sync_batchnorm(model, sync_bn)
except ValueError as e:
self.logger.error('cfg.sync_bn should be "torch" or "mmcv", but got {}'.format(sync_bn))
raise e
if model_wrapper_cfg is None:
find_unused_parameters = self.cfg.get('find_unused_parameters', False)
model = MMDistributedDataParallel(
module=model,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters,
)
else:
model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel')
model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_cfg.get('type'))
default_args: dict = dict()
if issubclass(model_wrapper_type, DistributedDataParallel):
default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])]
default_args['module'] = model
model = MODEL_WRAPPERS.build(model_wrapper_cfg, default_args=default_args)
return model
功能概述:
这个方法用于包装模型,以便在分布式训练环境中正确地处理模型。它可以根据配置将模型包装为特定的分布式数据并行模块,或者在非分布式环境下进行一些特定的处理。
参数说明:
model_wrapper_cfg
:一个字典,包含用于包装模型的配置信息。如果为None
,在分布式环境下会使用默认的配置进行包装。model
:一个nn.Module
对象,即要被包装的模型。
方法步骤解析:
-
检查模型是否已被包装:
- 如果模型已经是一个包装后的模型(通过
is_model_wrapper
函数判断),并且model_wrapper_cfg
不为None
,则抛出一个TypeError
异常,提示用户模型已经被包装,此时model_wrapper_cfg
应该为None
。
- 如果模型已经是一个包装后的模型(通过
-
非分布式环境处理:
- 如果当前不是分布式训练环境:
- 将模型移动到当前设备(通过
model.to(get_device())
实现,get_device
函数可能用于获取当前可用的设备)。 - 打印提示信息,告知用户在非分布式训练环境下,如果模型中使用了同步批归一化(SyncBatchNorm)层,将会自动转换为普通的批归一化层(BatchNormXd)。
- 使用
revert_sync_batchnorm
函数将模型中的同步批归一化层转换为普通批归一化层,并返回模型。
- 将模型移动到当前设备(通过
- 如果当前不是分布式训练环境:
-
分布式环境处理:
- 如果当前是分布式训练环境:
- 获取配置中的同步批归一化参数
sync_bn
。 - 如果
sync_bn
不为None
,尝试使用convert_sync_batchnorm
函数将模型转换为指定类型的同步批归一化模型,如果转换过程中出现错误(如不支持的sync_bn
类型),则抛出一个ValueError
异常,并打印错误信息。
- 获取配置中的同步批归一化参数
- 如果
model_wrapper_cfg
为None
:- 设置
find_unused_parameters
参数为配置中的值(如果没有配置则默认为False
)。 - 创建一个
MMDistributedDataParallel
对象来包装模型,传入模型、当前设备 ID(通过int(os.environ['LOCAL_RANK'])
获取)、广播缓冲区参数(设置为False
)以及find_unused_parameters
参数。
- 设置
- 如果
model_wrapper_cfg
不为None
:- 在配置字典中设置默认的包装类型为
'MMDistributedDataParallel'
。 - 获取配置中的包装类型对应的类(通过
MODEL_WRAPPERS.get(model_wrapper_cfg.get('type'))
实现)。 - 如果包装类型是
DistributedDataParallel
的子类,则设置一些默认参数,如设备 ID 和模型模块。 - 使用
MODEL_WRAPPERS.build
函数根据配置和默认参数构建包装后的模型。
- 在配置字典中设置默认的包装类型为
- 如果当前是分布式训练环境:
-
返回处理后的模型:
- 返回包装后的模型,在分布式环境下是一个分布式数据并行模块,在非分布式环境下可能是经过转换的普通模型。
8. 初始化模型参数
def _init_model_weights(self) -> None:
"""Initialize the model weights if the model has
:meth:`init_weights`"""
model = self.model.module if is_model_wrapper(
self.model) else self.model
if hasattr(model, 'init_weights'):
model.init_weights()
# sync params and buffers
for name, params in model.state_dict().items():
broadcast(params)
功能概述:
这个方法用于初始化模型的权重。如果模型有init_weights
方法,则调用该方法进行权重初始化,并在分布式环境下同步模型的参数和缓冲区。
方法步骤解析:
-
获取实际的模型对象:
- 如果模型被包装在一个模型包装器中(通过
is_model_wrapper
函数判断),则获取包装器中的实际模型模块(通过self.model.module
获取);否则,直接使用self.model
作为实际的模型对象。
- 如果模型被包装在一个模型包装器中(通过
-
检查并初始化权重:
- 检查模型是否具有
init_weights
方法。如果有,则调用model.init_weights()
来初始化模型的权重。这通常是特定模型类中定义的用于初始化权重的方法,可以根据模型的具体结构和需求进行定制化的权重初始化操作。
- 检查模型是否具有
-
同步参数和缓冲区(分布式环境下):
- 如果当前是分布式训练环境,遍历模型的状态字典(通过
model.state_dict()
获取),对于每个参数名和对应的参数值,使用broadcast
函数进行广播操作。这确保了在分布式训练中,所有进程的模型参数和缓冲区都保持一致。
- 如果当前是分布式训练环境,遍历模型的状态字典(通过
9. 自动缩放学习率
def scale_lr(self,
optim_wrapper: OptimWrapper,
auto_scale_lr: Optional[Dict] = None) -> None:
"""Automatically scaling learning rate in training according to the
ratio of ``base_batch_size`` in ``autoscalelr_cfg`` and real batch
size.
It scales the learning rate linearly according to the
`paper <https://arxiv.org/abs/1706.02677>`_.
Note:
``scale_lr`` must be called after building optimizer wrappers
and before building parameter schedulers.
Args:
optim_wrapper (OptimWrapper): An OptimWrapper object whose
parameter groups' learning rate need to be scaled.
auto_scale_lr (Dict, Optional): Config to scale the learning
rate automatically. It includes ``base_batch_size`` and
``enable``. ``base_batch_size`` is the batch size that the
optimizer lr is based on. ``enable`` is the switch to turn on
and off the feature.
"""
if (auto_scale_lr is None or not auto_scale_lr.get('enable', False)):
return None
assert 'base_batch_size' in auto_scale_lr, \
'Lack of `base_batch_size` in `auto_scale_lr`.'
dataloader: Union[DataLoader, Dict] = self._train_dataloader
bs = dataloader.batch_size if isinstance(
dataloader, DataLoader) else dataloader['batch_size']
real_bs = self.world_size * bs
base_bs = auto_scale_lr['base_batch_size']
ratio = float(real_bs) / float(base_bs)
self.logger.info(f'LR is set based on batch size of {base_bs} '
f'and the current batch size is {real_bs}. '
f'Scaling the original LR by {ratio}.')
def _is_built(schedulers):
if isinstance(schedulers, dict):
return False if 'type' in schedulers else any(
_is_built(s) for s in schedulers.values())
if isinstance(schedulers, list):
return any(_is_built(s) for s in schedulers)
return isinstance(schedulers, _ParamScheduler)
if _is_built(self.param_schedulers):
raise RuntimeError('`scale_lr` should be called before building '
'ParamScheduler because ParamScheduler will '
'store initial lr from optimizer wrappers')
assert isinstance(optim_wrapper, OptimWrapper), \
'`scale_lr should be called after building OptimWrapper'
wrappers = list(optim_wrapper.values()) if isinstance(
optim_wrapper, OptimWrapperDict) else [optim_wrapper]
for wrapper in wrappers:
for group in wrapper.optimizer.param_groups:
group['lr'] = group['lr'] * ratio
功能概述:
这个方法用于在训练过程中根据配置自动缩放学习率。它根据“自动缩放学习率配置”中的“基础批量大小”和实际批量大小的比例来线性地缩放学习率。
参数说明:
optim_wrapper
:一个OptimWrapper
对象,其参数组的学习率需要被缩放。auto_scale_lr
:一个可选的字典,用于配置自动缩放学习率。它包含“基础批量大小”和“启用”开关等参数。如果为None
或者“启用”开关为False
,则不进行学习率缩放操作。
方法步骤解析:
-
检查自动缩放配置是否启用:
- 如果
auto_scale_lr
为None
或者auto_scale_lr
中的“启用”开关为False
,则直接返回,不进行任何学习率缩放操作。
- 如果
-
确保存在“基础批量大小”参数:
- 检查
auto_scale_lr
中是否存在“基础批量大小”参数,如果没有,则抛出异常,提示缺少该参数。
- 检查
-
计算批量大小比例并打印信息:
- 获取训练数据加载器的批量大小。如果数据加载器是一个
DataLoader
对象,则直接获取其批量大小;如果是一个字典,则从字典中获取“batch_size”键对应的值。 - 计算实际的批量大小,即世界大小(参与训练的进程数)乘以单个进程的数据加载器批量大小。
- 根据配置中的“基础批量大小”和实际批量大小计算比例。
- 打印日志信息,告知用户学习率是基于配置的基础批量大小设置的,当前的批量大小是多少,并说明正在根据比例缩放原始学习率。
- 获取训练数据加载器的批量大小。如果数据加载器是一个
-
检查参数调度器是否已构建:
- 定义一个内部函数
_is_built
,用于检查参数调度器是否已经构建。如果参数调度器是一个字典,检查是否存在“type”键(如果存在则表示未构建),否则递归检查字典中的每个值;如果是一个列表,递归检查列表中的每个元素;如果是一个_ParamScheduler
类型的对象,则表示已构建。 - 如果参数调度器已经构建,则抛出一个
RuntimeError
异常,提示“scale_lr”应该在构建参数调度器之前被调用,因为参数调度器会从优化器包装器中存储初始学习率。
- 定义一个内部函数
-
缩放学习率:
- 确保
optim_wrapper
是一个OptimWrapper
对象。如果optim_wrapper
是一个OptimWrapperDict
类型的对象(可能是一个字典形式的优化器包装器),则将其转换为一个列表。 - 遍历优化器包装器列表中的每个包装器。
- 对于每个包装器中的优化器的每个参数组,将参数组的学习率乘以计算得到的比例,实现学习率的缩放。
- 确保
10. 优化器包装器构建
build_optim_wrapper
方法用于构建优化器包装器。它可以根据不同的配置构建单个优化器的包装器或多个优化器的包装器字典。
def build_optim_wrapper(
self, optim_wrapper: Union[Optimizer, OptimWrapper, Dict]
) -> Union[OptimWrapper, OptimWrapperDict]:
"""Build optimizer wrapper.
If ``optim_wrapper`` is a config dict for only one optimizer,
the keys must contain ``optimizer``, and ``type`` is optional.
It will build a :obj:`OptimWrapper` by default.
If ``optim_wrapper`` is a config dict for multiple optimizers, i.e.,
it has multiple keys and each key is for an optimizer wrapper. The
constructor must be specified since
:obj:`DefaultOptimizerConstructor` cannot handle the building of
training with multiple optimizers.
If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e.,
each value of ``optim_wrapper`` represents an ``OptimWrapper``
instance. ``build_optim_wrapper`` will directly build the
:obj:`OptimWrapperDict` instance from ``optim_wrapper``.
Args:
optim_wrapper (OptimWrapper or dict): An OptimWrapper object or a
dict to build OptimWrapper objects. If ``optim_wrapper`` is an
OptimWrapper, just return an ``OptimizeWrapper`` instance.
Note:
For single optimizer training, if `optim_wrapper` is a config
dict, `type` is optional(defaults to :obj:`OptimWrapper`) and it
must contain `optimizer` to build the corresponding optimizer.
Examples:
>>> # build an optimizer
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
... type='SGD', lr=0.01))
>>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
>>> # is also valid.
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
dampening: 0
lr: 0.01
momentum: 0
nesterov: False
weight_decay: 0
)
>>> # build optimizer without `type`
>>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
>>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
Type: OptimWrapper
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
dampening: 0
lr: 0.01
maximize: False
momentum: 0
nesterov: False
weight_decay: 0
)
>>> # build multiple optimizers
>>> optim_wrapper_cfg = dict(
... generator=dict(type='OptimWrapper', optimizer=dict(
... type='SGD', lr=0.01)),
... discriminator=dict(type='OptimWrapper', optimizer=dict(
... type='Adam', lr=0.001))
... # need to customize a multiple optimizer constructor
... constructor='CustomMultiOptimizerConstructor',
...)
>>> optim_wrapper = runner.optim_wrapper(optim_wrapper_cfg)
>>> optim_wrapper
name: generator
Type: OptimWrapper
accumulative_counts: 1
optimizer:
SGD (
Parameter Group 0
dampening: 0
lr: 0.1
momentum: 0
nesterov: False
weight_decay: 0
)
name: discriminator
Type: OptimWrapper
accumulative_counts: 1
optimizer:
'discriminator': Adam (
Parameter Group 0
dampening: 0
lr: 0.02
momentum: 0
nesterov: False
weight_decay: 0
)
Important:
If you need to build multiple optimizers, you should implement a
MultiOptimWrapperConstructor which gets parameters passed to
corresponding optimizers and compose the ``OptimWrapperDict``.
More details about how to customize OptimizerConstructor can be
found at `optimizer-docs`_.
Returns:
OptimWrapper: Optimizer wrapper build from ``optimizer_cfg``.
""" # noqa: E501
if isinstance(optim_wrapper, OptimWrapper):
return optim_wrapper
if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
# optimizer must be defined for single optimizer training.
optimizer = optim_wrapper.get('optimizer', None)
# If optimizer is a built `Optimizer` instance, the optimizer
# wrapper should be built by `OPTIM_WRAPPERS` registry.
if isinstance(optimizer, Optimizer):
optim_wrapper.setdefault('type', 'OptimWrapper')
return OPTIM_WRAPPERS.build(optim_wrapper) # type: ignore
# If `optimizer` is not None or `constructor` is defined, it means,
# optimizer wrapper will be built by optimizer wrapper
# constructor. Therefore, `build_optim_wrapper` should be called.
if optimizer is not None or 'constructor' in optim_wrapper:
return build_optim_wrapper(self.model, optim_wrapper)
else:
# if `optimizer` is not defined, it should be the case of
# training with multiple optimizers. If `constructor` is not
# defined either, each value of `optim_wrapper` must be an
# `OptimWrapper` instance since `DefaultOptimizerConstructor`
# will not handle the case of training with multiple
# optimizers. `build_optim_wrapper` will directly build the
# `OptimWrapperDict` instance from `optim_wrapper.`
optim_wrappers = OrderedDict()
for name, optim in optim_wrapper.items():
if not isinstance(optim, OptimWrapper):
raise ValueError(
'each item mush be an optimizer object when '
'"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')
功能概述:
这个方法用于构建优化器包装器。它可以根据不同的输入类型构建单个优化器的包装器或者多个优化器的包装器字典。
参数说明:
optim_wrapper
:可以是一个OptimWrapper
对象、一个字典或者一个优化器对象。如果是OptimWrapper
对象,直接返回该对象;如果是字典,可以是单个优化器的配置、多个优化器的配置或者已经构建好的优化器包装器字典;如果是优化器对象,会根据配置构建优化器包装器。
方法步骤解析:
-
处理输入为
OptimWrapper
对象的情况:- 如果
optim_wrapper
是一个OptimWrapper
对象,直接返回这个对象,不进行任何进一步的构建操作。
- 如果
-
处理输入为字典的情况:
- 首先尝试从字典中获取
optimizer
键对应的值。如果存在这个值,说明可能是单个优化器的配置或者多个优化器的配置。 - 如果
optimizer
的值是一个已经构建好的Optimizer
对象:- 在字典中设置默认的键
type
为'OptimWrapper'
(如果没有设置的话)。 - 使用
OPTIM_WRAPPERS.build(optim_wrapper)
构建一个优化器包装器并返回。这里的OPTIM_WRAPPERS
可能是一个注册器,用于根据配置构建优化器包装器。
- 在字典中设置默认的键
- 如果
optimizer
的值不是None
或者字典中存在constructor
键:- 调用
build_optim_wrapper
函数(可能是一个递归调用)来构建优化器包装器。这个情况通常是为了处理需要自定义构造函数的情况,例如构建单个优化器但需要特殊的构造逻辑。
- 调用
- 如果
optimizer
的值为None
且字典中也没有constructor
键:- 这种情况被认为是训练多个优化器的情况。创建一个有序字典
optim_wrappers
。 - 遍历输入字典的键值对,对于每个键值对,如果值不是一个
OptimWrapper
对象,则抛出一个ValueError
异常,提示每个值应该是一个优化器对象。 - 如果值是一个
OptimWrapper
对象,则将其添加到optim_wrappers
中。 - 最后使用
OptimWrapperDict(**optim_wrappers)
构建一个优化器包装器字典并返回。
- 这种情况被认为是训练多个优化器的情况。创建一个有序字典
- 首先尝试从字典中获取
-
处理输入类型错误的情况:
- 如果
optim_wrapper
既不是OptimWrapper
对象也不是字典,则抛出一个TypeError
异常,提示输入的参数类型应该是OptimWrapper
对象或者字典。
- 如果
注意事项和示例:
- 对于单个优化器训练,如果
optim_wrapper
是一个字典,键type
是可选的(默认为OptimWrapper
),并且必须包含键optimizer
来构建相应的优化器。 - 如果需要构建多个优化器,应该实现一个
MultiOptimWrapperConstructor
,它获取传递给相应优化器的参数并组成OptimWrapperDict
。更多关于如何自定义优化器构造函数的细节可以在“optimizer-docs”中找到。
示例中展示了不同情况下如何构建优化器包装器,包括构建单个优化器、不指定type
键构建单个优化器以及构建多个优化器的情况。
11. 参数调度器构建
build_param_scheduler
方法用于构建参数调度器。它会根据优化器的数量和配置构建不同类型的参数调度器。如果只有一个优化器,它会返回一个参数调度器列表。如果有多个优化器,它会返回一个字典,其中每个键对应一个优化器,值是一个参数调度器列表。
def _build_param_scheduler(
self, scheduler: Union[_ParamScheduler, Dict, List],
optim_wrapper: OptimWrapper) -> List[_ParamScheduler]:
"""Build parameter schedulers for a single optimizer.
Args:
scheduler (_ParamScheduler or dict or list): A Param Scheduler
object or a dict or list of dict to build parameter schedulers.
optim_wrapper (OptimWrapper): An optimizer wrapper object is
passed to construct ParamScheduler object.
Returns:
list[_ParamScheduler]: List of parameter schedulers build from
``scheduler``.
Note:
If the train loop is built, when building parameter schedulers,
it supports setting the max epochs/iters as the default ``end``
of schedulers, and supports converting epoch-based schedulers
to iter-based according to the ``convert_to_iter_based`` key.
"""
if not isinstance(scheduler, Sequence):
schedulers = [scheduler]
else:
schedulers = scheduler
param_schedulers = []
for scheduler in schedulers:
if isinstance(scheduler, _ParamScheduler):
param_schedulers.append(scheduler)
elif isinstance(scheduler, dict):
_scheduler = copy.deepcopy(scheduler)
# Set default end
if isinstance(self._train_loop, BaseLoop):
default_end = self.max_epochs if _scheduler.get(
'by_epoch', True) else self.max_iters
_scheduler.setdefault('end', default_end)
self.logger.debug(
f'The `end` of {_scheduler["type"]} is not set. '
'Use the max epochs/iters of train loop as default.')
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(
optimizer=optim_wrapper,
epoch_length=len(self.train_dataloader))))
else:
raise TypeError(
'scheduler should be a _ParamScheduler object or dict, '
f'but got {scheduler}')
return param_schedulers
def build_param_scheduler(
self, scheduler: Union[_ParamScheduler, Dict,
List]) -> ParamSchedulerType:
"""Build parameter schedulers.
``build_param_scheduler`` should be called after
``build_optim_wrapper`` because the building logic will change
according to the number of optimizers built by the runner.
The cases are as below:
- Single optimizer: When only one optimizer is built and used in the
runner, ``build_param_scheduler`` will return a list of
parameter schedulers.
- Multiple optimizers: When two or more optimizers are built and used
in runner, ``build_param_scheduler`` will return a dict containing
the same keys with multiple optimizers and each value is a list of
parameter schedulers. Note that, if you want different optimizers to
use different parameter schedulers to update optimizer's
hyper-parameters, the input parameter ``scheduler`` also needs to be
a dict and its key are consistent with multiple optimizers.
Otherwise, the same parameter schedulers will be used to update
optimizer's hyper-parameters.
Args:
scheduler (_ParamScheduler or dict or list): A Param Scheduler
object or a dict or list of dict to build parameter schedulers.
Examples:
>>> # build one scheduler
>>> optim_cfg = dict(dict(type='SGD', lr=0.01))
>>> runner.optim_wrapper = runner.build_optim_wrapper(
>>> optim_cfg)
>>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2])
>>> schedulers = runner.build_param_scheduler(scheduler_cfg)
>>> schedulers
[<mmengine.optim.scheduler.lr_scheduler.MultiStepLR at 0x7f70f6966290>] # noqa: E501
>>> # build multiple schedulers
>>> scheduler_cfg = [
... dict(type='MultiStepLR', milestones=[1, 2]),
... dict(type='StepLR', step_size=1)
... ]
>>> schedulers = runner.build_param_scheduler(scheduler_cfg)
>>> schedulers
[<mmengine.optim.scheduler.lr_scheduler.MultiStepLR at 0x7f70f60dd3d0>, # noqa: E501
<mmengine.optim.scheduler.lr_scheduler.StepLR at 0x7f70f6eb6150>]
Above examples only provide the case of one optimizer and one scheduler
or multiple schedulers. If you want to know how to set parameter
scheduler when using multiple optimizers, you can find more examples
`optimizer-docs`_.
Returns:
list[_ParamScheduler] or dict[str, list[_ParamScheduler]]: List of
parameter schedulers or a dictionary contains list of parameter
schedulers build from ``scheduler``.
.. _optimizer-docs:
https://mmengine.readthedocs.io/en/latest/tutorials/optim_wrapper.html
"""
param_schedulers: ParamSchedulerType
if not isinstance(self.optim_wrapper, OptimWrapperDict):
# Since `OptimWrapperDict` inherits from `OptimWrapper`,
# `isinstance(self.optim_wrapper, OptimWrapper)` cannot tell
# whether `self.optim_wrapper` is an `OptimizerWrapper` or
# `OptimWrapperDict` instance. Therefore, here we simply check
# self.optim_wrapper is not an `OptimWrapperDict` instance and
# then assert it is an OptimWrapper instance.
assert isinstance(self.optim_wrapper, OptimWrapper), (
'`build_optimizer` should be called before'
'`build_param_scheduler` because the latter depends '
'on the former')
param_schedulers = self._build_param_scheduler(
scheduler, self.optim_wrapper) # type: ignore
return param_schedulers
else:
param_schedulers = dict()
for name, optimizer in self.optim_wrapper.items():
if isinstance(scheduler, dict) and 'type' not in scheduler:
# scheduler is a dict and each item is a ParamScheduler
# object or a config to build ParamScheduler objects
param_schedulers[name] = self._build_param_scheduler(
scheduler[name], optimizer)
else:
param_schedulers[name] = self._build_param_scheduler(
scheduler, optimizer)
return param_schedulers
_build_param_scheduler
函数:
功能概述:
这个函数用于为单个优化器构建参数调度器。它可以接受一个参数调度器对象、一个字典或者一个字典列表,并将其转换为一个参数调度器列表。
参数说明:
scheduler
:可以是一个_ParamScheduler
对象、一个字典或者一个字典列表,用于构建参数调度器。optim_wrapper
:一个优化器包装器对象,用于传递给参数调度器的构造函数。
方法步骤解析:
-
处理输入不是序列的情况:
- 如果
scheduler
不是一个序列(如列表、元组等),将其转换为一个包含单个元素的列表schedulers
。
- 如果
-
遍历调度器列表进行构建:
- 创建一个空列表
param_schedulers
用于存储构建好的参数调度器。 - 对于
schedulers
中的每个调度器:- 如果调度器已经是一个
_ParamScheduler
对象,直接将其添加到param_schedulers
中。 - 如果调度器是一个字典:
- 进行深度复制以避免修改原始字典。
- 如果训练循环已经构建,根据调度器的
by_epoch
属性设置默认的结束值(end
)为最大轮数(如果按轮数调度)或最大迭代次数(如果按迭代次数调度)。并打印日志提示用户使用了训练循环的最大轮数或迭代次数作为默认结束值。 - 使用
PARAM_SCHEDULERS.build
函数根据字典构建参数调度器,并传入优化器包装器和训练数据加载器的长度作为默认参数,然后将构建好的参数调度器添加到param_schedulers
中。
- 如果调度器既不是
_ParamScheduler
对象也不是字典,抛出一个TypeError
异常,提示输入的参数类型错误。
- 如果调度器已经是一个
- 创建一个空列表
-
返回参数调度器列表:
- 返回构建好的参数调度器列表
param_schedulers
。
- 返回构建好的参数调度器列表
注意事项:
如果训练循环已经构建,这个函数支持设置最大轮数或迭代次数作为调度器的默认结束值,并支持根据convert_to_iter_based
键将基于轮数的调度器转换为基于迭代次数的调度器。
build_param_scheduler
函数:
功能概述:
这个函数用于构建参数调度器。根据是否有多个优化器,它会返回不同类型的结果。如果只有一个优化器,返回一个参数调度器列表;如果有多个优化器,返回一个字典,其中键是优化器的名称,值是参数调度器列表。
参数说明:
scheduler
:可以是一个_ParamScheduler
对象、一个字典或者一个字典列表,用于构建参数调度器。
方法步骤解析:
- 确定参数调度器类型:
- 创建一个变量
param_schedulers
用于存储构建好的参数调度器。 - 如果优化器包装器不是一个
OptimWrapperDict
对象(即只有一个优化器):- 断言优化器包装器是一个
OptimWrapper
对象,以确保在构建参数调度器之前已经构建了优化器。 - 调用
_build_param_scheduler
函数为单个优化器构建参数调度器,并将结果赋值给param_schedulers
。 - 返回参数调度器列表
param_schedulers
。
- 断言优化器包装器是一个
- 如果优化器包装器是一个
OptimWrapperDict
对象(即有多个优化器):- 创建一个空字典
param_schedulers
用于存储每个优化器对应的参数调度器列表。 - 对于优化器包装器中的每个优化器名称和对应的优化器对象:
- 如果
scheduler
是一个字典且没有type
键:- 调用
_build_param_scheduler
函数为当前优化器构建参数调度器,传入scheduler
中对应优化器名称的调度器配置和当前优化器对象,并将结果添加到param_schedulers
字典中。
- 调用
- 否则:
- 调用
_build_param_scheduler
函数为当前优化器构建参数调度器,传入scheduler
和当前优化器对象,并将结果添加到param_schedulers
字典中。
- 调用
- 如果
- 返回参数调度器字典
param_schedulers
。
- 创建一个空字典
- 创建一个变量
注意事项和示例:
这个函数应该在构建优化器包装器之后调用,因为其构建逻辑会根据构建的优化器数量而变化。示例中展示了构建一个调度器和多个调度器的情况,以及在使用多个优化器时如何设置参数调度器的示例链接。
12. 评估器构建
build_evaluator
方法用于构建评估器。它可以根据不同的配置构建单个评估器或多个评估器的列表。
def build_evaluator(self, evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
if isinstance(evaluator, Evaluator):
return evaluator
elif isinstance(evaluator, dict):
if 'metrics' in evaluator:
evaluator.setdefault('type', 'Evaluator')
return EVALUATOR.build(evaluator)
else:
return Evaluator(evaluator)
elif isinstance(evaluator, list):
return Evaluator(evaluator)
else:
raise TypeError(
'evaluator should be one of dict, list of dict, and Evaluator, but got {}'.format(evaluator)
)
功能概述:
这个方法用于构建评估器。它可以接受一个已经构建好的评估器对象、一个字典配置、一个字典列表或者一个评估器列表,并根据不同的输入类型构建相应的评估器。
参数说明:
evaluator
:可以是一个Evaluator
对象、一个字典、一个字典列表或者一个评估器列表。如果是Evaluator
对象,直接返回该对象;如果是字典,根据字典的内容构建评估器;如果是列表,根据列表中的元素构建评估器。
方法步骤解析:
-
处理输入为
Evaluator
对象的情况:- 如果
evaluator
是一个Evaluator
对象,直接返回这个对象,不进行任何进一步的构建操作。
- 如果
-
处理输入为字典的情况:
- 如果字典中包含键
metrics
,这意味着要构建自定义的评估器:- 在字典中设置默认的键
type
为'Evaluator'
(如果没有设置的话)。 - 使用
EVALUATOR.build(evaluator)
构建评估器并返回。这里的EVALUATOR
可能是一个评估器构建器,根据传入的字典配置构建评估器。
- 在字典中设置默认的键
- 如果字典中不包含键
metrics
,则构建默认的评估器:- 使用
Evaluator(evaluator)
构建评估器并返回,但由于类型注释的原因,添加了type: ignore
注释。这里的Evaluator
是一个评估器类,接受字典参数进行初始化。
- 使用
- 如果字典中包含键
-
处理输入为列表的情况:
- 构建默认的评估器:
- 使用
Evaluator(evaluator)
构建评估器并返回,但由于类型注释的原因,添加了type: ignore
注释。这里的Evaluator
是一个评估器类,接受列表参数进行初始化,列表中的元素可以是字典配置或者已经构建好的评估器对象。
- 使用
- 构建默认的评估器:
-
处理输入类型错误的情况:
- 如果
evaluator
既不是Evaluator
对象、字典也不是列表,则抛出一个TypeError
异常,提示输入的参数类型应该是字典、列表或者Evaluator
对象。
- 如果