MMDetection Series: Advanced Guide (Customizing Runtime Settings: Optimizer, Other Settings, Learning Rate, Workflow, Useful Hooks, and Custom Hooks)

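1. Hooks

Hook programming is a programming pattern in which mount points are set at one or more locations in a program. When execution reaches a mount point, all methods registered to that mount point at runtime are called automatically. Hook programming makes a program more flexible and extensible, because users can register custom methods at the mount points without modifying the program's code.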
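The idea can be illustrated with a minimal sketch in plain Python. MiniRunner, register, and call_hooks are hypothetical names invented for this illustration; they are not MMEngine's Runner API, only a toy model of mount points and registered callables.

```python
from collections import defaultdict


class MiniRunner:
    """Toy runner with named mount points (hypothetical, not MMEngine's Runner)."""

    def __init__(self):
        # Maps a mount point name to the callables registered on it.
        self._hooks = defaultdict(list)

    def register(self, point, fn):
        """Register a callable on a mount point at runtime."""
        self._hooks[point].append(fn)

    def call_hooks(self, point, **kwargs):
        """Call everything registered on a mount point, in registration order."""
        for fn in self._hooks[point]:
            fn(**kwargs)

    def train(self, num_iters):
        # The loop itself never changes; behaviour is added via hooks.
        self.call_hooks('before_run')
        for index in range(num_iters):
            self.call_hooks('after_iter', index=index)
        self.call_hooks('after_run')


events = []
runner = MiniRunner()
runner.register('before_run', lambda: events.append('start'))
runner.register('after_iter', lambda index: events.append(f'iter {index}'))
runner.register('after_run', lambda: events.append('done'))
runner.train(num_iters=2)
print(events)
```

The training loop is untouched by the three registered callables, which is the extensibility benefit described above.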

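1. Built-in hooks

MMEngine wraps many utilities as built-in hooks. They fall into two categories: default hooks and custom hooks. Default hooks are registered with the Runner by default; custom hooks are registered by the user as needed.

Each hook has a priority. At each mount point, the Runner calls hooks with higher priority first. Hooks that share the same priority are called in the order in which they were registered. The priority list is as follows.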

HIGHEST (0)
VERY_HIGH (10)
HIGH (30)
ABOVE_NORMAL (40)
NORMAL (50)
BELOW_NORMAL (60)
LOW (70)
VERY_LOW (90)
LOWEST (100)
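A small sketch of this ordering rule, assuming that a lower numeric value means the hook is called earlier and that ties keep registration order. The hook names and the call_order helper are hypothetical; only the priority table is taken from the list above.

```python
# Priority names and values copied from the list above.
PRIORITIES = {
    'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'ABOVE_NORMAL': 40,
    'NORMAL': 50, 'BELOW_NORMAL': 60, 'LOW': 70, 'VERY_LOW': 90,
    'LOWEST': 100,
}


def call_order(registered):
    """Return hook names in the order they would be called.

    `registered` is a list of (name, priority) pairs in registration order.
    sorted() is stable, so hooks with equal priority keep that order.
    """
    ordered = sorted(registered, key=lambda item: PRIORITIES[item[1]])
    return [name for name, _ in ordered]


# Hypothetical hooks, listed in registration order.
hooks = [
    ('ckpt_like', 'VERY_LOW'),
    ('custom_a', 'NORMAL'),
    ('early', 'VERY_HIGH'),
    ('custom_b', 'NORMAL'),
]
print(call_order(hooks))
```

Here custom_a runs before custom_b only because it was registered first; both share NORMAL priority.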

1. Default hooks


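2. Custom hooks

Note: It is not recommended to modify the priority of the default hooks, because lower-priority hooks may depend on higher-priority ones. For example, CheckpointHook needs a lower priority than ParamSchedulerHook so that the saved optimizer state is correct. In addition, the priority of custom hooks defaults to NORMAL (50).

The two types of hooks are configured differently in the Runner: the configuration of default hooks is passed to the Runner's default_hooks parameter, and the configuration of custom hooks is passed to the custom_hooks parameter, as follows.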

from mmengine.runner import Runner
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)
custom_hooks = [dict(type='EmptyCacheHook')]
runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()

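3. Introduction to the hooks

1. LoggerHook

LoggerHook collects logs from the different components of the Runner and writes them to the terminal, JSON files, TensorBoard, wandb, and so on.
To output (or save) logs every 20 iterations, set the interval parameter as follows.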

default_hooks = dict(logger=dict(type='LoggerHook', interval=20))

If you are interested in how MMEngine manages logging, see the logging documentation.

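2. CheckpointHook

CheckpointHook saves checkpoints at a given interval. In distributed training, only the master process saves checkpoints. The main features of CheckpointHook are as follows.

Save checkpoints at intervals, by epoch or by iteration
Save the most recent checkpoints
Save the best checkpoint
Specify the checkpoint save path
Publish checkpoints
Control the epoch or iteration at which checkpoint saving begins

For more features, see the CheckpointHook API documentation.

The six features above are described in turn below.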

(1) Save checkpoints at intervals, by epoch or by iteration
interval (int) – The saving period. If by_epoch=True, interval indicates epochs, otherwise it indicates iterations. Defaults to -1, which means “never”.

Suppose we train for 20 epochs in total and want to save a checkpoint every 5 epochs. The following configuration does this.

# the default value of by_epoch is True
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))

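(2) Save the most recent checkpoints

If you only want to keep a certain number of checkpoints, set the max_keep_ckpts parameter. When the number of saved checkpoints exceeds max_keep_ckpts, the oldest checkpoints are deleted.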

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))

With the configuration above, if training runs for 20 epochs in total, the model is saved at epochs 5, 10, 15, and 20, but epoch_5.pth is deleted at epoch 15 and epoch_10.pth is deleted at epoch 20, so only epoch_15.pth and epoch_20.pth remain.
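The retention behaviour can be traced with a short simulation. This is a sketch of the bookkeeping only, assuming checkpoints are named epoch_N.pth and the oldest one is dropped first; simulate_retention is a hypothetical helper, not part of MMEngine.

```python
from collections import deque


def simulate_retention(max_epochs, interval, max_keep_ckpts):
    """Return (kept checkpoints, deletion log) for an epoch-based schedule."""
    kept = deque()
    deletions = []  # (epoch at which deletion happens, deleted file)
    for epoch in range(1, max_epochs + 1):
        if epoch % interval != 0:
            continue
        kept.append(f'epoch_{epoch}.pth')
        # Once the limit is exceeded, drop the oldest checkpoint.
        if len(kept) > max_keep_ckpts:
            deletions.append((epoch, kept.popleft()))
    return list(kept), deletions


kept, deletions = simulate_retention(max_epochs=20, interval=5, max_keep_ckpts=2)
print(kept)
print(deletions)
```

Running it with interval=5 and max_keep_ckpts=2 over 20 epochs reproduces the sequence described above.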

(3) Save the best checkpoint
save_best (str, List[str], optional) – If a metric is specified, it would measure the best checkpoint during evaluation. If a list of metrics is passed, it would measure a group of best checkpoints corresponding to the passed metrics. The information about best checkpoint(s) would be saved in runner.message_hub to keep best score value and best checkpoint path, which will be also loaded when resuming checkpoint. Options are the evaluation metrics on the test dataset. e.g., bbox_mAP, segm_mAP for bbox detection and instance segmentation. AR@100 for proposal recall. If save_best is auto, the first key of the returned OrderedDict result will be used. Defaults to None

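To keep the checkpoint that performs best on the validation set during training, set the save_best parameter. If it is set to 'auto', the first evaluation metric on the validation set is used to decide whether the current checkpoint is the best (the evaluation metrics returned by the evaluator are an ordered dictionary).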

default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))

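You can also set save_best directly to the name of an evaluation metric. For example, in a classification task you can specify save_best='top-1', and the best checkpoint is then chosen according to the 'top-1' value.

Besides save_best, the other parameters related to saving the best checkpoint are rule, greater_keys, and less_keys, which indicate whether a larger or a smaller value is better. For example, if you specify save_best='top-1', you can set rule='greater' to indicate that a larger value means a better checkpoint.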
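The following sketch shows one way a comparison rule could be inferred when rule is left unset. It is a simplified approximation, not MMEngine's actual implementation: infer_rule is a hypothetical helper, and the matching is a plain case-insensitive substring test. The default key lists are the ones that appear in the CheckpointHook source quoted later in this article.

```python
DEFAULT_GREATER_KEYS = [
    'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
    'mAcc', 'aAcc'
]
DEFAULT_LESS_KEYS = ['loss']


def infer_rule(metric, rule=None,
               greater_keys=DEFAULT_GREATER_KEYS,
               less_keys=DEFAULT_LESS_KEYS):
    """Return 'greater' or 'less' for a metric name (simplified sketch)."""
    if rule is not None:
        # An explicit rule always wins over inference.
        if rule not in ('greater', 'less'):
            raise ValueError(f'unknown rule: {rule!r}')
        return rule
    name = metric.lower()
    if any(key.lower() in name for key in greater_keys):
        return 'greater'
    if any(key.lower() in name for key in less_keys):
        return 'less'
    raise ValueError(f'cannot infer a rule for {metric!r}; set rule explicitly')


for metric in ('top-1', 'bbox_mAP', 'val_loss'):
    print(metric, '->', infer_rule(metric))
```

A metric such as 'FID' matches neither list in this sketch, which is why a rule has to be given explicitly for metrics like that.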

(4) Specify the checkpoint save path

Checkpoints are saved in work_dir by default; the path can be changed by setting out_dir.

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
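According to the CheckpointHook docstring quoted later in this article, a user-supplied out_dir is joined with the last directory level of work_dir. The sketch below mirrors that rule for POSIX-style paths; resolve_ckpt_dir is a hypothetical helper written for illustration.

```python
import posixpath


def resolve_ckpt_dir(work_dir, out_dir=None):
    """Return the directory checkpoints would end up in (sketch)."""
    if out_dir is None:
        # No out_dir given: checkpoints go straight into work_dir.
        return work_dir
    # Otherwise append the last level of work_dir to out_dir.
    basename = posixpath.basename(work_dir.rstrip('/'))
    return posixpath.join(out_dir, basename)


print(resolve_ckpt_dir('./work_dir/cur_exp'))
print(resolve_ckpt_dir('./work_dir/cur_exp', out_dir='./tmp'))
```

So with out_dir='./tmp' and work_dir='./work_dir/cur_exp', the checkpoints land in ./tmp/cur_exp rather than directly in ./tmp.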

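(5) Publish checkpoints

If you want to automatically generate publishable checkpoints after training (with unnecessary keys such as the optimizer state removed), set the published_keys parameter to choose which information to keep. Note: save_best or save_last must be set accordingly for a publishable checkpoint to be generated. Setting save_best publishes the weights of the best checkpoint, and setting save_last publishes the final checkpoint. Both parameters can be set at the same time.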

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='greater', published_keys=['meta', 'state_dict']))
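Conceptually, publishing keeps only the selected top-level keys of a checkpoint. The sketch below models a checkpoint as a plain dict and filters it; filter_published is a hypothetical helper, the checkpoint contents are made up, and writing the published file to disk is not modelled here.

```python
def filter_published(checkpoint, published_keys):
    """Keep only the top-level keys listed in published_keys (sketch)."""
    if isinstance(published_keys, str):
        published_keys = [published_keys]
    return {k: v for k, v in checkpoint.items() if k in published_keys}


# Made-up checkpoint contents, for illustration only.
checkpoint = {
    'meta': {'epoch': 12},
    'state_dict': {'backbone.conv1.weight': [0.1, 0.2]},
    'optimizer': {'lr': 0.01},
    'param_schedulers': [],
}
published = filter_published(checkpoint, ['meta', 'state_dict'])
print(sorted(published))
```

The optimizer and scheduler states are dropped, which is what makes the published file smaller and suitable for sharing.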

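(6) Control the epoch or iteration at which checkpoint saving begins

To control the epoch or iteration from which weights start being saved, set the save_begin parameter. It defaults to 0, which means checkpoints are saved from the start of training. For example, if you train for 10 epochs in total and save_begin is set to 5, the checkpoints of epochs 5, 6, 7, 8, 9, and 10 are saved. With interval=2, only the checkpoints of epochs 5, 7, and 9 are saved.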

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
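The interaction between interval and save_begin can be checked with a small sketch. It assumes a rule of the form "(epoch number - save_begin) is a non-negative multiple of interval"; should_save is a hypothetical helper, and the save_last option (which the source quoted later uses to force a save at the final epoch) is deliberately ignored.

```python
def should_save(epoch_index, interval, save_begin=0):
    """Decide whether to save after the epoch with zero-based index epoch_index.

    Checkpoint names are one-based, so epoch_index 4 corresponds to epoch 5.
    """
    offset = epoch_index + 1 - save_begin
    return interval > 0 and offset >= 0 and offset % interval == 0


every_epoch = [e + 1 for e in range(10) if should_save(e, interval=1, save_begin=5)]
every_other = [e + 1 for e in range(10) if should_save(e, interval=2, save_begin=5)]
print(every_epoch)
print(every_other)
```

The two printed lists correspond to the two cases described above: every epoch from 5 onward, and every second epoch starting at 5.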
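3. ParamSchedulerHook

ParamSchedulerHook iterates over all parameter schedulers of the Runner and calls their step method in order to update the optimizer parameters. For more details, see the parameter scheduler documentation.

ParamSchedulerHook is registered with the Runner by default and has no configurable parameters, so no configuration is needed.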

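4. IterTimerHook

IterTimerHook records the time spent loading data and running one iteration.
IterTimerHook is registered with the Runner by default and has no configurable parameters, so no configuration is needed.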

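5. DistSamplerSeedHook

During distributed training, DistSamplerSeedHook calls the sampler's set_epoch method so that shuffling takes effect.
DistSamplerSeedHook is registered with the Runner by default and has no configurable parameters, so no configuration is needed.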

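6. RuntimeInfoHook

RuntimeInfoHook updates the current runtime information (e.g. epoch, iter, max_epochs, max_iters, lr, metrics) into the message hub at different mount points of the Runner, so that modules without access to the Runner can obtain this information.
RuntimeInfoHook is registered with the Runner by default and has no configurable parameters, so no configuration is needed.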

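7. EMAHook

EMAHook applies an exponential moving average to the model during training in order to improve the model's robustness. Note that the model produced by the exponential moving average is used only for validation and testing and does not affect training.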

custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

EMAHook uses ExponentialMovingAverage by default; the other options are StochasticWeightAverage and MomentumAnnealingEMA. Set ema_type to use a different averaging strategy.

custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]

For more usage, see the EMAHook API reference.
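As a rough illustration of what an exponential moving average of parameters does, the sketch below applies the update ema = (1 - momentum) * ema + momentum * current, which is the formula given in the BaseEMAHook docstring quoted later in this article. Parameters are modelled as plain lists of floats, and ema_update is a hypothetical helper, not part of MMEngine.

```python
def ema_update(ema_params, cur_params, momentum):
    """Return the updated EMA copy of the parameters (sketch)."""
    return [(1 - momentum) * e + momentum * c
            for e, c in zip(ema_params, cur_params)]


# The EMA copy starts at 0.0 while the trained parameter stays at 1.0.
ema = [0.0]
history = []
for _ in range(3):
    ema = ema_update(ema, [1.0], momentum=0.5)
    history.append(ema[0])
print(history)
```

The averaged copy moves toward the current value without ever feeding back into it, which matches the note above that the EMA model does not affect training.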

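8. EmptyCacheHook

EmptyCacheHook calls torch.cuda.empty_cache() to release all unoccupied cached GPU memory. When the memory is released is controlled by before_epoch, after_iter, and after_epoch, which mean before each epoch starts, after each iteration, and after each epoch, respectively.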

# The release operation is performed at the end of each epoch
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
9. SyncBuffersHook

SyncBuffersHook synchronizes the buffers of the model at the end of each epoch during distributed training, e.g. running_mean and running_var of the BN layers.

custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
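10. DetVisualizationHook

DetVisualizationHook uses DetLocalVisualizer to visualize prediction results. DetLocalVisualizer currently supports different backends, such as TensorboardVisBackend and WandbVisBackend (see the vis_backend docstrings for more details). Users can add multiple backends for visualization, as follows.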

default_hooks = dict(
    visualization=dict(type='DetVisualizationHook', draw=True))

vis_backends = [dict(type='LocalVisBackend'),
                dict(type='TensorboardVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')

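2. Customize hooks

1. Implement a new hook (2.3.0)

In some cases, users may need to implement a new hook. Since v2.3.0, MMDetection supports customized hooks in training (#3395). Users can therefore implement a hook directly in mmdet or in a codebase based on mmdet, and use it by only modifying the training config. Before v2.3.0, users had to modify the code to register the hook before training started. Here is an example of creating a new hook in mmdet and using it in training.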

from mmcv.runner import HOOKS, Hook
@HOOKS.register_module()
class MyHook(Hook):

    def __init__(self, a, b):
        pass

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

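v3: If the built-in hooks provided by MMEngine do not meet your needs, you are encouraged to write your own hook by inheriting from the Hook base class and overriding the corresponding mount point methods. For example, to check during training whether the loss value is valid (i.e. finite), override the after_train_iter method as follows. The check runs after training iterations, once every interval iterations.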

import torch
from mmengine.registry import HOOKS
from mmengine.hooks import Hook
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.
    This hook will regularly check whether the loss is valid
    during training.
    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """
    def __init__(self, interval=50):
        self.interval = interval
    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """All subclasses should override this method, if they need any
        operations after each training iteration.
        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']),\
                runner.logger.info('loss become infinite or NaN!')

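Depending on what the hook does, users need to specify what it does at each stage of training in before_run, after_run, before_train, after_train, before_train_epoch, after_train_epoch, before_train_iter, and after_train_iter. There are more points at which hooks can be inserted; see the base Hook class for details.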


# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import HOOKS, Hook


class BaseEMAHook(Hook):
    """Exponential Moving Average Hook.
    Use Exponential Moving Average on all parameters of model in training
    process. All parameters have a ema backup, which update by the formula
    as below. EMAHook takes priority over EvalHook and CheckpointHook. Note,
    the original model parameters are actually saved in ema field after train.
    Args:
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameter are updated with the formula:
           `ema_param = (1-momentum) * ema_param + momentum * cur_param`.
            Defaults to 0.0002.
        skip_buffers (bool): Whether to skip the model buffers, such as
            batchnorm running stats (running_mean, running_var), it does not
            perform the ema operation. Default to False.
        interval (int): Update ema parameter every interval iteration.
            Defaults to 1.
        resume_from (str, optional): The checkpoint path. Defaults to None.
        momentum_fun (func, optional): The function to change momentum
            during early iteration (also warmup) to help early training.
            It uses `momentum` as a constant. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 skip_buffers=False,
                 resume_from=None,
                 momentum_fun=None):
        assert 0 < momentum < 1
        self.momentum = momentum
        self.skip_buffers = skip_buffers
        self.interval = interval
        self.checkpoint = resume_from
        self.momentum_fun = momentum_fun

    def before_run(self, runner):
        """To resume model with it's ema parameters more friendly.
        Register ema parameter as ``named_buffer`` to model.
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        if self.skip_buffers:
            self.model_parameters = dict(model.named_parameters())
        else:
            self.model_parameters = model.state_dict()
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers())
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def get_momentum(self, runner):
        return self.momentum_fun(runner.iter) if self.momentum_fun else \
                        self.momentum

    def after_train_iter(self, runner):
        """Update ema parameter every self.interval iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        momentum = self.get_momentum(runner)
        for name, parameter in self.model_parameters.items():
            # exclude num_tracking
            if parameter.dtype.is_floating_point:
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                buffer_parameter.mul_(1 - momentum).add_(
                    parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """We recover model's parameter from ema backup after last epoch's
        EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameter of model with parameter in ema_buffer."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)


@HOOKS.register_module()
class ExpMomentumEMAHook(BaseEMAHook):
    """EMAHook using exponential momentum strategy.
    Args:
        total_iter (int): The total number of iterations of EMA momentum.
           Defaults to 2000.
    """

    def __init__(self, total_iter=2000, **kwargs):
        super(ExpMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-(
            1 + x) / total_iter) + self.momentum


@HOOKS.register_module()
class LinearMomentumEMAHook(BaseEMAHook):
    """EMAHook using linear momentum strategy.
    Args:
        warm_up (int): During first warm_up steps, we may use smaller decay
            to update ema parameters more slowly. Defaults to 100.
    """

    def __init__(self, warm_up=100, **kwargs):
        super(LinearMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: min(self.momentum**self.interval,
                                          (1 + x) / (warm_up + x))
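To get a feel for the exponential momentum schedule used by ExpMomentumEMAHook above, the sketch below evaluates the same expression as its momentum_fun on its own. exp_momentum is a standalone helper written for this illustration; the default values are the ones shown in the code above.

```python
import math


def exp_momentum(x, momentum=0.0002, total_iter=2000):
    """Same expression as ExpMomentumEMAHook.momentum_fun, as a plain function."""
    return (1 - momentum) * math.exp(-(1 + x) / total_iter) + momentum


for it in (0, 2000, 10000, 50000):
    print(it, round(exp_momentum(it), 6))
```

The value starts close to 1, so the EMA copy tracks the model almost exactly in early iterations, and decays toward the base momentum as training proceeds.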

2. Register the new hook

Then we need to import MyHook. Assuming the file is mmdet/core/utils/my_hook.py, there are two ways to do this:

Modify mmdet/core/utils/__init__.py to import it.

The newly defined module should be imported in mmdet/core/utils/__init__.py so that the registry will find the new module and add it:

from .my_hook import MyHook

Use custom_imports in the config to manually import it

custom_imports = dict(imports=['mmdet.core.utils.my_hook'], allow_failed_imports=False)

custom_imports = dict(imports=['mmdet.engine.hooks.my_hook'], allow_failed_imports=False)


# Copyright (c) OpenMMLab. All rights reserved.
from .checkloss_hook import CheckInvalidLossHook
from .ema import ExpMomentumEMAHook, LinearMomentumEMAHook  # imports the hooks defined above
from .memory_profiler_hook import MemoryProfilerHook
from .set_epoch_info_hook import SetEpochInfoHook
from .sync_norm_hook import SyncNormHook
from .sync_random_size_hook import SyncRandomSizeHook
from .wandblogger_hook import MMDetWandbHook
from .yolox_lrupdater_hook import YOLOXLrUpdaterHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook

__all__ = [
    'SyncRandomSizeHook', 'YOLOXModeSwitchHook', 'SyncNormHook',
    'ExpMomentumEMAHook', 'LinearMomentumEMAHook', 'YOLOXLrUpdaterHook',
    'CheckInvalidLossHook', 'SetEpochInfoHook', 'MemoryProfilerHook',
    'MMDetWandbHook'
]

V3: We only need to pass the hook config to the custom_hooks parameter of the Runner, which registers the hook when the Runner is initialized.

from mmengine.runner import Runner
custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()  # start training

The loss value is then checked after training iterations.

Note that a custom hook has priority NORMAL (50) by default. To change it, set the priority key in the config.

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]

You can also set the priority when defining the class.

@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    priority = 'ABOVE_NORMAL'

3. Modify the config

custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value)
]

You can also set the hook's priority by setting the priority key to 'NORMAL' or 'HIGHEST', as shown below.

custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL')
]

By default the hook’s priority is set as NORMAL during registration.

custom_hooks = [dict(
                type='LinearMomentumEMAHook',
                momentum=0.001,
                interval=1,
                warm_up=100,
                resume_from=None
                )]
The YOLOX config file uses the following:
custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(
        type='SyncNormHook',
        num_last_epochs=num_last_epochs,
        interval=interval,
        priority=48),
    dict(
        type='ExpMomentumEMAHook',
        resume_from=resume_from,
        momentum=0.0001,
        priority=49)
]

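3. Use hooks implemented in MMCV

If the hook is already implemented in MMCV, you can use it by modifying the config directly, as shown below.
Modify default runtime hooks
Example: NumClassCheckHook
We implement a custom hook named NumClassCheckHook to check whether num_classes in the head matches the length of CLASSES in the dataset.
We set it in default_runtime.py.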

custom_hooks = [dict(type='NumClassCheckHook')]

4. Modify default runtime hooks

log_config
checkpoint_config
evaluation
lr_config
optimizer_config
momentum_config

Among these hooks, only the logger hook has VERY_LOW priority; the others have NORMAL priority. The tutorials above already covered how to modify optimizer_config, momentum_config, and lr_config. Here we show what can be done with log_config, checkpoint_config, and evaluation.

1. Checkpoint config

The MMCV runner uses checkpoint_config to initialize CheckpointHook.

checkpoint_config = dict(interval=1)
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Defaults to True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        save_param_scheduler (bool): Whether to save param_scheduler state_dict
            in the checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        out_dir (str, Path, Optional): The root directory to save checkpoints.
            If not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of ``out_dir``
            and the last level directory of ``runner.work_dir``. For example,
            if the input ``out_dir`` is ``./tmp`` and ``runner.work_dir`` is
            ``./work_dir/cur_exp``, then the ckpt will be saved in
            ``./tmp/cur_exp``. Defaults to None.
        max_keep_ckpts (int): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Defaults to -1, which means unlimited.
        save_last (bool): Whether to force the last checkpoint to be
            saved regardless of interval. Defaults to True.
        save_best (str, List[str], optional): If a metric is specified, it
            would measure the best checkpoint during evaluation. If a list of
            metrics is passed, it would measure a group of best checkpoints
            corresponding to the passed metrics. The information about best
            checkpoint(s) would be saved in ``runner.message_hub`` to keep
            best score value and best checkpoint path, which will be also
            loaded when resuming checkpoint. Options are the evaluation metrics
            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
            detection and instance segmentation. ``AR@100`` for proposal
            recall. If ``save_best`` is ``auto``, the first key of the returned
            ``OrderedDict`` result will be used. Defaults to None.
        rule (str, List[str], optional): Comparison rule for best score. If
            set to None, it will infer a reasonable rule. Keys such as 'acc',
            'top' .etc will be inferred by 'greater' rule. Keys contain 'loss'
            will be inferred by 'less' rule. If ``save_best`` is a list of
            metrics and ``rule`` is a str, all metrics in ``save_best`` will
            share the comparison rule. If ``save_best`` and ``rule`` are both
            lists, their length must be the same, and metrics in ``save_best``
            will use the corresponding comparison rule in ``rule``. Options
            are 'greater', 'less', None and list which contains 'greater' and
            'less'. Defaults to None.
        greater_keys (List[str], optional): Metric keys that will be
            inferred by 'greater' comparison rule. If ``None``,
            _default_greater_keys will be used. Defaults to None.
        less_keys (List[str], optional): Metric keys that will be
            inferred by 'less' comparison rule. If ``None``, _default_less_keys
            will be used. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None. It will be deprecated in future. Please use
            ``backend_args`` instead.
        filename_tmpl (str, optional): String template to indicate checkpoint
            name. If specified, must contain one and only one "{}", which will
            be replaced with ``epoch + 1`` if ``by_epoch=True`` else
            ``iteration + 1``.
            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
            accordingly.
        backend_args (dict, optional): Arguments to instantiate the
            prefix of uri corresponding backend. Defaults to None.
            `New in version 0.2.0.`
        published_keys (str, List[str], optional): If ``save_last`` is ``True``
            or ``save_best`` is not ``None``, it will automatically
            publish model with keys in the list after training.
            Defaults to None.
            `New in version 0.7.1.`
        save_begin (int): Control the epoch number or iteration number
            at which checkpoint saving begins. Defaults to 0, which means
            saving at the beginning.
            `New in version 0.8.3.`

    Examples:
        >>> # Save best based on single metric
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        >>>                rule='less')
        >>> # Save best based on multi metrics with the same comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        >>>                save_best=['acc', 'mIoU'], rule='greater')
        >>> # Save best based on multi metrics with different comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        >>>                save_best=['FID', 'IS'], rule=['less', 'greater'])
        >>> # Save best based on single metric and publish model after training
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        >>>                rule='less', published_keys=['meta', 'state_dict'])
    """
    out_dir: str

    priority = 'VERY_LOW'

    # logic to save best checkpoints
    # Since the key for determining greater or less is related to the
    # downstream tasks, downstream repositories may need to overwrite
    # the following inner variables accordingly.

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 save_param_scheduler: bool = True,
                 out_dir: Optional[Union[str, Path]] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 save_best: Union[str, List[str], None] = None,
                 rule: Union[str, List[str], None] = None,
                 greater_keys: Optional[Sequence[str]] = None,
                 less_keys: Optional[Sequence[str]] = None,
                 file_client_args: Optional[dict] = None,
                 filename_tmpl: Optional[str] = None,
                 backend_args: Optional[dict] = None,
                 published_keys: Union[str, List[str], None] = None,
                 save_begin: int = 0,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.save_param_scheduler = save_param_scheduler
        self.out_dir = out_dir  # type: ignore
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs

        if file_client_args is not None:
            print_log(
                '"file_client_args" will be deprecated in future. '
                'Please use "backend_args" instead',
                logger='current',
                level=logging.WARNING)
            if backend_args is not None:
                raise ValueError(
                    '"file_client_args" and "backend_args" cannot be set '
                    'at the same time.')

        self.file_client_args = file_client_args
        self.backend_args = backend_args

        if filename_tmpl is None:
            if self.by_epoch:
                self.filename_tmpl = 'epoch_{}.pth'
            else:
                self.filename_tmpl = 'iter_{}.pth'
        else:
            self.filename_tmpl = filename_tmpl

        # save best logic
        assert (isinstance(save_best, str) or is_list_of(save_best, str)
                or (save_best is None)), (
                    '"save_best" should be a str or list of str or None, '
                    f'but got {type(save_best)}')

        if isinstance(save_best, list):
            if 'auto' in save_best:
                assert len(save_best) == 1, (
                    'Only support one "auto" in "save_best" list.')
            assert len(save_best) == len(
                set(save_best)), ('Find duplicate element in "save_best".')
        else:
            # convert str to list[str]
            if save_best is not None:
                save_best = [save_best]  # type: ignore # noqa: F401
        self.save_best = save_best

        # rule logic
        assert (isinstance(rule, str) or is_list_of(rule, str)
                or (rule is None)), (
                    '"rule" should be a str or list of str or None, '
                    f'but got {type(rule)}')
        if isinstance(rule, list):
            # check the length of rule list
            assert len(rule) in [
                1,
                len(self.save_best)  # type: ignore
            ], ('Number of "rule" must be 1 or the same as number of '
                f'"save_best", but got {len(rule)}.')
        else:
            # convert str/None to list
            rule = [rule]  # type: ignore # noqa: F401

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )  # type: ignore
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys  # type: ignore

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )  # type: ignore
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys  # type: ignore

        if self.save_best is not None:
            self.is_better_than: Dict[str, Callable] = dict()
            self._init_rule(rule, self.save_best)
            if len(self.key_indicators) == 1:
                self.best_ckpt_path: Optional[str] = None
            else:
                self.best_ckpt_path_dict: Dict = dict()

        # published keys
        if not (isinstance(published_keys, str)
                or is_seq_of(published_keys, str) or published_keys is None):
            raise TypeError(
                '"published_keys" should be a str or a sequence of str or '
                f'None, but got {type(published_keys)}')

        if isinstance(published_keys, str):
            published_keys = [published_keys]
        elif isinstance(published_keys, (list, tuple)):
            assert len(published_keys) == len(set(published_keys)), (
                'Find duplicate elements in "published_keys".')
        self.published_keys = published_keys

        self.last_ckpt = None
        if save_begin < 0:
            raise ValueError(
                f'save_begin should not be less than 0, but got {save_begin}')
        self.save_begin = save_begin

    def before_train(self, runner) -> None:
        """Finish all operations, related to checkpoint.

        This function will get the appropriate file client, and the directory
        to save these checkpoints of the model.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.out_dir is None:
            self.out_dir = runner.work_dir

        # If self.file_client_args is None, self.file_client will not
        # used in CheckpointHook. To avoid breaking backward compatibility,
        # it will not be removed until the release of MMEngine1.0
        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        if self.file_client_args is None:
            self.file_backend = get_file_backend(
                self.out_dir, backend_args=self.backend_args)
        else:
            self.file_backend = self.file_client

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_backend.join_path(
                self.out_dir, basename)  # type: ignore  # noqa: E501

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.')

        if self.save_best is not None:
            if len(self.key_indicators) == 1:
                if 'best_ckpt' not in runner.message_hub.runtime_info:
                    self.best_ckpt_path = None
                else:
                    self.best_ckpt_path = runner.message_hub.get_info(
                        'best_ckpt')
            else:
                for key_indicator in self.key_indicators:
                    best_ckpt_name = f'best_ckpt_{key_indicator}'
                    if best_ckpt_name not in runner.message_hub.runtime_info:
                        self.best_ckpt_path_dict[key_indicator] = None
                    else:
                        self.best_ckpt_path_dict[
                            key_indicator] = runner.message_hub.get_info(
                                best_ckpt_name)

        if self.max_keep_ckpts > 0:
            keep_ckpt_ids = []
            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')

                while len(keep_ckpt_ids) > self.max_keep_ckpts:
                    step = keep_ckpt_ids.pop(0)
                    if is_main_process():
                        path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(step))
                        if self.file_backend.isfile(path):
                            self.file_backend.remove(path)
                        elif self.file_backend.isdir(path):
                            # checkpoints saved by deepspeed are directories
                            self.file_backend.rmtree(path)

            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
                                              self.max_keep_ckpts)

    def after_train_epoch(self, runner) -> None:
        """Save the checkpoint and synchronize buffers after each epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` epochs which start at ``self.save_begin``
        # 2. reach the last epoch of training
        if self.every_n_epochs(runner, self.interval, self.save_begin) or (
                self.save_last and self.is_last_train_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            self._save_checkpoint(runner)

    def after_val_epoch(self, runner, metrics):
        """Save the checkpoint and synchronize buffers after each evaluation
        epoch.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics
        """
        if len(metrics) == 0:
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return

        self._save_best_checkpoint(runner, metrics)

    def after_train(self, runner) -> None:
        """Publish the checkpoint after training.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.published_keys is None:
            return

        if self.save_last and self.last_ckpt is not None:
            self._publish_model(runner, self.last_ckpt)

        if getattr(self, 'best_ckpt_path', None) is not None:
            self._publish_model(runner, str(self.best_ckpt_path))
        if getattr(self, 'best_ckpt_path_dict', None) is not None:
            for best_ckpt in self.best_ckpt_path_dict.values():
                self._publish_model(runner, best_ckpt)

    @master_only
    def _publish_model(self, runner, ckpt_path: str) -> None:
        """Remove unnecessary keys from ckpt_path and save the new checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            ckpt_path (str): The checkpoint path that ought to be published.
        """
        from mmengine.runner import save_checkpoint
        from mmengine.runner.checkpoint import _load_checkpoint
        checkpoint = _load_checkpoint(ckpt_path)
        assert self.published_keys is not None
        removed_keys = []
        for key in list(checkpoint.keys()):
            if key not in self.published_keys:
                removed_keys.append(key)
                checkpoint.pop(key)
        if removed_keys:
            print_log(
                f'Key {removed_keys} will be removed because they are not '
                'found in published_keys. If you want to keep them, '
                f'please set `{removed_keys}` in published_keys',
                logger='current')
        checkpoint_data = pickle.dumps(checkpoint)
        sha = hashlib.sha256(checkpoint_data).hexdigest()
        final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth'
        save_checkpoint(checkpoint, final_path)
        print_log(
            f'The checkpoint ({ckpt_path}) is published to '
            f'{final_path}.',
            logger='current')

    def _save_checkpoint_with_step(self, runner, step, meta):
        # remove other checkpoints before save checkpoint to make the
        # self.keep_ckpt_ids are saved as expected
        if self.max_keep_ckpts > 0:
            # _save_checkpoint and _save_best_checkpoint may call this
            # _save_checkpoint_with_step in one epoch
            if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step:
                pass
            else:
                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
                    _step = self.keep_ckpt_ids.popleft()
                    if is_main_process():
                        ckpt_path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(_step))

                        if self.file_backend.isfile(ckpt_path):
                            self.file_backend.remove(ckpt_path)
                        elif self.file_backend.isdir(ckpt_path):
                            # checkpoints saved by deepspeed are directories
                            self.file_backend.rmtree(ckpt_path)

                self.keep_ckpt_ids.append(step)
                runner.message_hub.update_info('keep_ckpt_ids',
                                               list(self.keep_ckpt_ids))

        ckpt_filename = self.filename_tmpl.format(step)
        self.last_ckpt = self.file_backend.join_path(self.out_dir,
                                                     ckpt_filename)
        runner.message_hub.update_info('last_ckpt', self.last_ckpt)

        runner.save_checkpoint(
            self.out_dir,
            ckpt_filename,
            self.file_client_args,
            save_optimizer=self.save_optimizer,
            save_param_scheduler=self.save_param_scheduler,
            meta=meta,
            by_epoch=self.by_epoch,
            backend_args=self.backend_args,
            **self.args)

        # Model parallel-like training should involve pulling sharded states
        # from all ranks, but skip the following procedure.
        if not is_main_process():
            return

        save_file = osp.join(runner.work_dir, 'last_checkpoint')
        with open(save_file, 'w') as f:
            f.write(self.last_ckpt)  # type: ignore

    def _save_checkpoint(self, runner) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.by_epoch:
            step = runner.epoch + 1
            meta = dict(epoch=step, iter=runner.iter)
        else:
            step = runner.iter + 1
            meta = dict(epoch=runner.epoch, iter=step)

        self._save_checkpoint_with_step(runner, step, meta=meta)

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        meta = dict(epoch=runner.epoch, iter=runner.iter)

        # handle auto in self.key_indicators and self.rules before the loop
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        best_ckpt_updated = False
        # save best logic
        # get score from messagehub
        for key_indicator, rule in zip(self.key_indicators, self.rules):
            key_score = metrics[key_indicator]

            if len(self.key_indicators) == 1:
                best_score_key = 'best_score'
                runtime_best_ckpt_key = 'best_ckpt'
                best_ckpt_path = self.best_ckpt_path
            else:
                best_score_key = f'best_score_{key_indicator}'
                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

            if best_score_key not in runner.message_hub.runtime_info:
                best_score = self.init_value_map[rule]
            else:
                best_score = runner.message_hub.get_info(best_score_key)

            if key_score is None or not self.is_better_than[key_indicator](
                    key_score, best_score):
                continue

            best_ckpt_updated = True

            best_score = key_score
            runner.message_hub.update_info(best_score_key, best_score)

            if best_ckpt_path and is_main_process():
                is_removed = False
                if self.file_backend.isfile(best_ckpt_path):
                    self.file_backend.remove(best_ckpt_path)
                    is_removed = True
                elif self.file_backend.isdir(best_ckpt_path):
                    # checkpoints saved by deepspeed are directories
                    self.file_backend.rmtree(best_ckpt_path)
                    is_removed = True

                if is_removed:
                    runner.logger.info(
                        f'The previous best checkpoint {best_ckpt_path} '
                        'is removed')

            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
            # Replace illegal characters for filename with `_`
            best_ckpt_name = best_ckpt_name.replace('/', '_')
            if len(self.key_indicators) == 1:
                self.best_ckpt_path = self.file_backend.join_path(  # type: ignore # noqa: E501
                    self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(runtime_best_ckpt_key,
                                               self.best_ckpt_path)
            else:
                self.best_ckpt_path_dict[
                    key_indicator] = self.file_backend.join_path(  # type: ignore # noqa: E501
                        self.out_dir, best_ckpt_name)
                runner.message_hub.update_info(
                    runtime_best_ckpt_key,
                    self.best_ckpt_path_dict[key_indicator])
            runner.save_checkpoint(
                self.out_dir,
                filename=best_ckpt_name,
                file_client_args=self.file_client_args,
                save_optimizer=False,
                save_param_scheduler=False,
                meta=meta,
                by_epoch=False,
                backend_args=self.backend_args)
            runner.logger.info(
                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

        # save checkpoint again to update the best_score and best_ckpt stored
        # in message_hub because the checkpoint saved in `after_train_epoch`
        # or `after_train_iter` stage only keep the previous best checkpoint
        # not the current best checkpoint which causes the current best
        # checkpoint can not be removed when resuming training.
        if best_ckpt_updated and self.last_ckpt is not None:
            self._save_checkpoint_with_step(runner, cur_time, meta)

    def _init_rule(self, rules, key_indicators) -> None:
        """Initialize rule, key_indicator, comparison_func, and best score. If
        key_indicator is a list of string and rule is a string, all metric in
        the key_indicator will share the same rule.

        Here is the rule to determine which rule is used for key indicator when
        the rule is not specific (note that the key indicator matching is case-
        insensitive):

        1. If the key indicator is in ``self.greater_keys``, the rule
            will be specified as 'greater'.
        2. Or if the key indicator is in ``self.less_keys``, the rule
            will be specified as 'less'.
        3. Or if any one item in ``self.greater_keys`` is a substring of
            key_indicator, the rule will be specified as 'greater'.
        4. Or if any one item in ``self.less_keys`` is a substring of
            key_indicator, the rule will be specified as 'less'.

        Args:
            rule (List[Optional[str]]): Comparison rule for best score.
            key_indicator (List[str]): Key indicator to determine
                the comparison rule.
        """
        if len(rules) == 1:
            rules = rules * len(key_indicators)

        self.rules = []
        for rule, key_indicator in zip(rules, key_indicators):

            if rule not in self.rule_map and rule is not None:
                raise KeyError('rule must be greater, less or None, '
                               f'but got {rule}.')

            if rule is None and key_indicator != 'auto':
                # `_lc` here means we use the lower case of keys for
                # case-insensitive matching
                key_indicator_lc = key_indicator.lower()
                greater_keys = {key.lower() for key in self.greater_keys}
                less_keys = {key.lower() for key in self.less_keys}

                if key_indicator_lc in greater_keys:
                    rule = 'greater'
                elif key_indicator_lc in less_keys:
                    rule = 'less'
                elif any(key in key_indicator_lc for key in greater_keys):
                    rule = 'greater'
                elif any(key in key_indicator_lc for key in less_keys):
                    rule = 'less'
                else:
                    raise ValueError('Cannot infer the rule for key '
                                     f'{key_indicator}, thus a specific rule '
                                     'must be specified.')
            if rule is not None:
                self.is_better_than[key_indicator] = self.rule_map[rule]
            self.rules.append(rule)

        self.key_indicators = key_indicators

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Save the checkpoint and synchronize buffers after each iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` iterations
        #       which start at ``self.save_begin``
        # 2. reach the last iteration of training
        if self.every_n_train_iters(runner, self.interval,
                                    self.save_begin) or \
                (self.save_last and
                 self.is_last_train_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            self._save_checkpoint(runner)

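Users can set max_keep_ckpts to keep only a small number of checkpoints, and use save_optimizer to decide whether the optimizer's state dict is stored in the checkpoint.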
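
The effect of max_keep_ckpts can be illustrated without any training framework. The sketch below imitates the sliding-window bookkeeping seen in `_save_checkpoint_with_step` using a plain `deque`; `simulate_rotation` is a hypothetical name, no files are touched, and it assumes `max_keep_ckpts` is a positive integer.

```python
from collections import deque


def simulate_rotation(steps, max_keep_ckpts):
    """Return (kept, removed) checkpoint steps after saving `steps` in order."""
    kept = deque(maxlen=max_keep_ckpts)
    removed = []
    for step in steps:
        # Saving the same step twice in a row must not evict anything.
        if kept and kept[-1] == step:
            continue
        # Evict the oldest checkpoint before appending when the window is full.
        if len(kept) == max_keep_ckpts:
            removed.append(kept.popleft())
        kept.append(step)
    return list(kept), removed


kept, removed = simulate_rotation([1, 2, 3, 4, 5], max_keep_ckpts=3)
print(kept, removed)
```

With five saves and a window of three, the two oldest steps are the ones reported as removed.
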

CheckpointHook
# Save checkpoints periodically.
mmcv.runner.CheckpointHook(interval=-1, by_epoch=True, save_optimizer=True, out_dir=None, max_keep_ckpts=-1, save_last=True, sync_buffer=False, file_client_args=None, **kwargs)
Parameters
interval (int) – The saving period. If by_epoch=True, interval indicates epochs, otherwise it indicates iterations. Default: -1, which means “never”.

by_epoch (bool) – Saving checkpoints by epoch or by iteration. Default: True.

save_optimizer (bool) – Whether to save optimizer state_dict in the checkpoint. It is usually used for resuming experiments. Default: True.

out_dir (str, optional) – The root directory to save checkpoints. If not specified, runner.work_dir will be used by default. If specified, the out_dir will be the concatenation of out_dir and the last level directory of runner.work_dir. Changed in version 1.3.16.

max_keep_ckpts (int, optional) – The maximum checkpoints to keep. In some cases we want only the latest few checkpoints and would like to delete old ones to save the disk space. Default: -1, which means unlimited.

save_last (bool, optional) – Whether to force the last checkpoint to be saved regardless of interval. Default: True.

sync_buffer (bool, optional) – Whether to synchronize buffers in different gpus. Default: False.

file_client_args (dict, optional) – Arguments to instantiate a FileClient. See mmcv.fileio.FileClient for details. Default: None. New in version 1.3.16.

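Note:
Before v1.3.16, the out_dir argument indicated the path where checkpoints were stored. Since v1.3.16, out_dir indicates the root directory, and the final path for saving checkpoints is the concatenation of out_dir and the last-level directory of runner.work_dir. Suppose out_dir is "/path/of/A" and runner.work_dir is "/path/of/B"; the final path is then "/path/of/A/B".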
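
The path rule in the note can be checked with a few lines of standard-library code. This is a minimal sketch of the concatenation only; `resolve_ckpt_dir` is a hypothetical helper name, and the paths are the ones used in the note.

```python
import os.path as osp


def resolve_ckpt_dir(out_dir, work_dir):
    """Join `out_dir` with the last-level directory name of `work_dir`."""
    if out_dir is None:
        # No out_dir given: checkpoints go to work_dir itself.
        return work_dir
    basename = osp.basename(work_dir.rstrip(osp.sep))
    return osp.join(out_dir, basename)


print(resolve_ckpt_dir('/path/of/A', '/path/of/B'))
print(resolve_ckpt_dir(None, '/path/of/B'))
```
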

2、 Log config

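log_config wraps multiple logger hooks and allows setting the logging interval. MMCV currently supports WandbLoggerHook, MlflowLoggerHook, and TensorboardLoggerHook. Detailed usage can be found in the LoggerHook documentation.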

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
LoggerHook
class mmcv.runner.LoggerHook(interval=10, ignore_last=True, reset_flag=False, by_epoch=True)
Base class for logger hooks.
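
To make the meaning of `interval` concrete, here is a minimal sketch of an interval check. It assumes a 0-based iteration counter and a simple modulo test; `should_log` is a hypothetical name, and the real logger hooks may apply additional conditions (for example around the end of an epoch).

```python
def should_log(cur_iter, interval):
    """True when the 0-based iteration closes an `interval`-sized window."""
    return interval > 0 and (cur_iter + 1) % interval == 0


# With interval=50, logs are emitted after iterations 50, 100, 150, ...
logged = [i + 1 for i in range(200) if should_log(i, interval=50)]
print(logged)
```
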


3、 Evaluation config

The evaluation config is used to initialize the EvalHook. Except for the key interval, the other arguments, such as metric, are passed to dataset.evaluate().

evaluation = dict(interval=1, metric='bbox')
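
How the keys of this config are divided can be sketched as follows. This is an illustration based on the `EvalHook.__init__` signature shown in this section, which also accepts `start`; `split_eval_cfg` is a hypothetical helper, not part of MMDetection.

```python
def split_eval_cfg(evaluation):
    """Separate the hook's own keys from the kwargs forwarded to evaluate()."""
    cfg = dict(evaluation)  # copy so the original config is untouched
    interval = cfg.pop('interval', 1)
    start = cfg.pop('start', None)
    return dict(interval=interval, start=start), cfg


hook_args, eval_kwargs = split_eval_cfg(dict(interval=1, metric='bbox'))
print(hook_args)    # consumed by the hook itself
print(eval_kwargs)  # forwarded to dataset.evaluate()
```
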
import os.path as osp
import warnings

from mmcv.runner import Hook
from torch.utils.data import DataLoader


class EvalHook(Hook):
    """Evaluation hook.
    Notes:
        If new arguments are added for EvalHook, tools/test.py may be
    affected.
    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        start (int, optional): Evaluation starting epoch. It enables evaluation
            before the training starts if ``start`` <= the resuming epoch.
            If None, whether to evaluate is merely decided by ``interval``.
            Default: None.
        interval (int): Evaluation interval (by epochs). Default: 1.
        **eval_kwargs: Evaluation arguments fed into the evaluate function of
            the dataset.
    """

    def __init__(self, dataloader, start=None, interval=1, **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, but got'
                            f' {type(dataloader)}')
        if not interval > 0:
            raise ValueError(f'interval must be positive, but got {interval}')
        if start is not None and start < 0:
            warnings.warn(
                f'The evaluation start epoch {start} is smaller than 0, '
                f'use 0 instead', UserWarning)
            start = 0
        self.dataloader = dataloader
        self.interval = interval
        self.start = start
        self.eval_kwargs = eval_kwargs
        self.initial_epoch_flag = True

    def before_train_epoch(self, runner):
        """Evaluate the model only at the start of training."""
        if not self.initial_epoch_flag:
            return
        if self.start is not None and runner.epoch >= self.start:
            self.after_train_epoch(runner)
        self.initial_epoch_flag = False

    def evaluation_flag(self, runner):
        """Judge whether to perform_evaluation after this epoch.
        Returns:
            bool: The flag indicating whether to perform evaluation.
        """
        if self.start is None:
            if not self.every_n_epochs(runner, self.interval):
                # No evaluation during the interval epochs.
                return False
        elif (runner.epoch + 1) < self.start:
            # No evaluation if start is larger than the current epoch.
            return False
        else:
            # Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
            if (runner.epoch + 1 - self.start) % self.interval:
                return False
        return True

    def after_train_epoch(self, runner):
        if not self.evaluation_flag(runner):
            return
        from mmdet.apis import single_gpu_test
        results = single_gpu_test(runner.model, self.dataloader, show=False)
        self.evaluate(runner, results)

    def evaluate(self, runner, results):
        eval_res = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True


class DistEvalHook(EvalHook):
    """Distributed evaluation hook.
    Notes:
        If new arguments are added, tools/test.py may be affected.
    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        start (int, optional): Evaluation starting epoch. It enables evaluation
            before the training starts if ``start`` <= the resuming epoch.
            If None, whether to evaluate is merely decided by ``interval``.
            Default: None.
        interval (int): Evaluation interval (by epochs). Default: 1.
        tmpdir (str | None): Temporary directory to save the results of all
            processes. Default: None.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Default: False.
        **eval_kwargs: Evaluation arguments fed into the evaluate function of
            the dataset.
    """

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 tmpdir=None,
                 gpu_collect=False,
                 **eval_kwargs):
        super().__init__(
            dataloader, start=start, interval=interval, **eval_kwargs)
        self.tmpdir = tmpdir
        self.gpu_collect = gpu_collect

    def after_train_epoch(self, runner):
        if not self.evaluation_flag(runner):
            return
        from mmdet.apis import multi_gpu_test
        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect)
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)
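
The interaction of `start` and `interval` in `evaluation_flag` is easier to see with concrete numbers. The sketch below re-implements the same check as a free function so it can run without a runner; it assumes `every_n_epochs` is the modulo test `(epoch + 1) % n == 0`.

```python
def evaluation_flag(epoch, interval=1, start=None):
    """Standalone version of the epoch check; `epoch` is 0-based."""
    if start is None:
        return (epoch + 1) % interval == 0
    if (epoch + 1) < start:
        return False
    return (epoch + 1 - start) % interval == 0


# 1-based epochs after which evaluation runs, over 8 training epochs.
with_start = [e + 1 for e in range(8) if evaluation_flag(e, interval=2, start=3)]
without_start = [e + 1 for e in range(8) if evaluation_flag(e, interval=2)]
print(with_start)
print(without_start)
```

With `start=3` and `interval=2` the schedule is anchored at the start epoch, while without `start` it is anchored at multiples of the interval.
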

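2、Customize optimizers supported by PyTorch

All optimizers implemented by PyTorch are already supported; the only change needed is to the optimizer field of the config file. For example, to use Adam (note that performance may drop considerably), the change can look like this.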

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)

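To change the model's learning rate, users only need to modify lr in the optimizer config. Other arguments can be set directly by following PyTorch's API documentation.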
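
As an illustration of setting further arguments, the config fragment below adds `betas` and `eps`, which are documented arguments of `torch.optim.Adam`; the values shown are that optimizer's defaults. The last two lines are only a sketch, under the assumption that every key except `type` is forwarded to the optimizer's constructor.

```python
optimizer = dict(
    type='Adam',
    lr=0.0003,            # learning rate
    betas=(0.9, 0.999),   # torch.optim.Adam default
    eps=1e-8,             # torch.optim.Adam default
    weight_decay=0.0001)

# Keyword arguments the optimizer class would receive (assumed behaviour).
optimizer_kwargs = {k: v for k, v in optimizer.items() if k != 'type'}
print(optimizer_kwargs)
```
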

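2、Customize self-implemented optimizers

1、Define a new optimizer

A customized optimizer can be defined as follows. Suppose you want to add an optimizer named MyOptimizer, which has arguments a, b, and c. You need to create a new directory named mmdet/core/optimizer, then implement the new optimizer in a file, e.g., in mmdet/core/optimizer/my_optimizer.py: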

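from .registry import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c):
        # Optimizer-specific initialization goes here.
        pass

2、Add the optimizer to registry

To find the module defined above, it must first be imported into the main namespace. There are two ways to achieve this.
Modify mmdet/core/optimizer/__init__.py to import it. The newly defined module should be imported in mmdet/core/optimizer/__init__.py so that the registry will find the new module and add it: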

from .my_optimizer import MyOptimizer

Use custom_imports in the config to manually import it:

custom_imports = dict(imports=['mmdet.core.optimizer.my_optimizer'], allow_failed_imports=False)

The module mmdet.core.optimizer.my_optimizer will be imported at the beginning of the program and the class MyOptimizer is then automatically registered. Note that only the package containing the class MyOptimizer should be imported. mmdet.core.optimizer.my_optimizer.MyOptimizer cannot be imported directly.
In fact, with this import method users can use an entirely different directory structure, as long as the module root can be located in PYTHONPATH.
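
What custom_imports amounts to can be sketched with `importlib`. The helper below is a hypothetical stand-in, not the library's implementation: it imports each dotted module path and, when `allow_failed_imports` is true, warns instead of raising. Standard-library module names are used in place of `mmdet.core.optimizer.my_optimizer` so the example runs anywhere.

```python
import importlib
import warnings


def import_modules(imports, allow_failed_imports=False):
    """Import each dotted module path; return the imported modules."""
    modules = []
    for name in imports:
        try:
            modules.append(importlib.import_module(name))
        except ImportError:
            if not allow_failed_imports:
                raise
            warnings.warn(f'{name} failed to import and is ignored.')
            modules.append(None)
    return modules


# 'no_such_module_xyz' is a deliberately missing module for the demo.
mods = import_modules(['json', 'no_such_module_xyz'], allow_failed_imports=True)
print([m.__name__ if m else None for m in mods])
```

Importing the module is what triggers any `register_module()` decorators inside it, which is why a module path is given rather than a class path.
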

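3、Specify the optimizer in the config file

Then you can use MyOptimizer in the optimizer field of the config file. In the configs, optimizers are defined by the optimizer field, as follows: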

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

To use your own optimizer, the field can be changed to

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
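
The way a `type` string in the config resolves to a registered class can be sketched with a tiny registry. Everything below is a simplified, hypothetical stand-in written for illustration: the `Registry` class is not the MMCV one, and `MyOptimizer` is a plain class rather than a `torch.optim.Optimizer` subclass, so the example has no dependencies.

```python
class Registry:
    """Tiny stand-in for a name-to-class registry (hypothetical)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg, **default_args):
        args = dict(cfg)  # do not mutate the caller's config
        cls = self._modules[args.pop('type')]
        return cls(**default_args, **args)


OPTIMIZERS = Registry('optimizer')


@OPTIMIZERS.register_module()
class MyOptimizer:
    # A plain class keeps the sketch free of a torch dependency.
    def __init__(self, params, a, b, c):
        self.params, self.a, self.b, self.c = params, a, b, c


optimizer = OPTIMIZERS.build(
    dict(type='MyOptimizer', a=1, b=2, c=3), params=[])
print(type(optimizer).__name__, optimizer.a, optimizer.b, optimizer.c)
```
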

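2、Customize optimizer constructor

Some models may have parameter-specific settings for optimization, e.g., weight decay for BatchNorm layers. Users can make such fine-grained parameter adjustments by customizing the optimizer constructor.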

from mmcv.utils import build_from_cfg

from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmdet.utils import get_root_logger
from .my_optimizer import MyOptimizer


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        # Constructor-specific initialization goes here.
        pass

    def __call__(self, model):
        # Build `my_optimizer` from the parameters of `model` here.

        return my_optimizer

The default optimizer constructor is implemented in default_constructor.py, which can also serve as a template for new optimizer constructors.

import warnings

import torch
from torch.nn import GroupNorm, LayerNorm

from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS


@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor:
    """Default constructor for optimizers.
    By default each parameter share the same optimizer settings, and we
    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
    It is a dict and may contain the following fields:
    - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
      one of the keys in ``custom_keys`` is a substring of the name of one
      parameter, then the setting of the parameter will be specified by
      ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
      be ignored. It should be noted that the aforementioned ``key`` is the
      longest key that is a substring of the name of the parameter. If there
      are multiple matched keys with the same length, then the key with lower
      alphabet order will be chosen.
      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
      and ``decay_mult``. See Example 2 below.
    - ``bias_lr_mult`` (float): It will be multiplied to the learning
      rate for all bias parameters (except for those in normalization
      layers).
    - ``bias_decay_mult`` (float): It will be multiplied to the weight
      decay for all bias parameters (except for those in
      normalization layers and depthwise conv layers).
    - ``norm_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of normalization
      layers.
    - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
      decay for all weight and bias parameters of depthwise conv
      layers.
    - ``bypass_duplicate`` (bool): If true, the duplicate parameters
      would not be added into optimizer. Default: False.
    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are
                - `type`: class name of the optimizer.
            Optional fields are
                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.
    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = dict(norm_decay_mult=0.)
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
        >>> paramwise_cfg = dict(custom_keys={
                '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict, '
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.base_lr = optimizer_cfg.get('lr', None)
        self.base_wd = optimizer_cfg.get('weight_decay', None)
        self._validate_cfg()

    def _validate_cfg(self):
        if not isinstance(self.paramwise_cfg, dict):
            raise TypeError('paramwise_cfg should be None or a dict, '
                            f'but got {type(self.paramwise_cfg)}')

        if 'custom_keys' in self.paramwise_cfg:
            if not isinstance(self.paramwise_cfg['custom_keys'], dict):
                raise TypeError(
                    'If specified, custom_keys must be a dict, '
                    f'but got {type(self.paramwise_cfg["custom_keys"])}')
            if self.base_wd is None:
                for key in self.paramwise_cfg['custom_keys']:
                    if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
                        raise ValueError('base_wd should not be None')

        # get base lr and weight decay
        # weight_decay must be explicitly specified if mult is specified
        if ('bias_decay_mult' in self.paramwise_cfg
                or 'norm_decay_mult' in self.paramwise_cfg
                or 'dwconv_decay_mult' in self.paramwise_cfg):
            if self.base_wd is None:
                raise ValueError('base_wd should not be None')

    def _is_in(self, param_group, param_group_list):
        assert is_list_of(param_group_list, dict)
        param = set(param_group['params'])
        param_set = set()
        for group in param_group_list:
            param_set.update(set(group['params']))

        return not param.isdisjoint(param_set)

    def add_params(self, params, module, prefix=''):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        is_dwconv = (
            isinstance(module, torch.nn.Conv2d)
            and module.in_channels == module.groups)

        for name, param in module.named_parameters(recurse=False):
            param_group = {'params': [param]}
            if not param.requires_grad:
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                warnings.warn(f'{prefix} is duplicate. It is skipped since '
                              f'bypass_duplicate={bypass_duplicate}')
                continue
            # if the parameter match one of the custom keys, ignore other rules
            is_custom = False
            for key in sorted_keys:
                if key in f'{prefix}.{name}':
                    is_custom = True
                    lr_mult = custom_keys[key].get('lr_mult', 1.)
                    param_group['lr'] = self.base_lr * lr_mult
                    if self.base_wd is not None:
                        decay_mult = custom_keys[key].get('decay_mult', 1.)
                        param_group['weight_decay'] = self.base_wd * decay_mult
                    break
            if not is_custom:
                # bias_lr_mult affects all bias parameters except for norm.bias
                if name == 'bias' and not is_norm:
                    param_group['lr'] = self.base_lr * bias_lr_mult
                # apply weight decay policies
                if self.base_wd is not None:
                    # norm decay
                    if is_norm:
                        param_group[
                            'weight_decay'] = self.base_wd * norm_decay_mult
                    # depth-wise conv
                    elif is_dwconv:
                        param_group[
                            'weight_decay'] = self.base_wd * dwconv_decay_mult
                    # bias lr and decay
                    elif name == 'bias':
                        param_group[
                            'weight_decay'] = self.base_wd * bias_decay_mult
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(params, child_mod, prefix=child_prefix)

    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module

        optimizer_cfg = self.optimizer_cfg.copy()
        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

        # set param-wise lr and weight decay recursively
        params = []
        self.add_params(params, model)
        optimizer_cfg['params'] = params

        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
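
The `custom_keys` matching rule used in `add_params` above (keys sorted alphabetically, then longest first, then matched as plain substrings) can be sketched in isolation. This is a minimal illustration, not MMCV code: `resolve_custom_key`, `lr_for`, and the example keys and values are hypothetical names chosen for the demo.

```python
def resolve_custom_key(param_name, custom_keys):
    """Return the custom key that applies to param_name, or None."""
    # Alphabetical sort first, then a stable sort by length (longest first),
    # so the most specific key is tried before shorter ones.
    sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
    for key in sorted_keys:
        if key in param_name:  # plain substring match, as in add_params
            return key
    return None


# Hypothetical settings, for illustration only.
custom_keys = {
    'backbone': dict(lr_mult=0.1),
    'backbone.layer0': dict(lr_mult=0.0),
    'head': dict(lr_mult=2.0, decay_mult=0.5),
}
base_lr = 0.02


def lr_for(param_name):
    """Learning rate a parameter would get under the custom_keys rule."""
    key = resolve_custom_key(param_name, custom_keys)
    lr_mult = custom_keys[key].get('lr_mult', 1.) if key else 1.
    return base_lr * lr_mult


for name in ('backbone.layer0.conv.weight', 'backbone.layer1.conv.weight',
             'head.fc.bias', 'neck.conv.weight'):
    print(name, resolve_custom_key(name, custom_keys), lr_for(name))
```

Because the longest key is tried first, `backbone.layer0` takes precedence over `backbone` for parameters that match both.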


# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from .builder import RUNNER_BUILDERS, RUNNERS


@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
    """Default constructor for runners.
    Customize an existing `Runner` like `EpochBasedRunner` through `RunnerConstructor`.
    For example, We can inject some new properties and functions for `Runner`.
    Example:
        >>> from mmcv.runner import RUNNER_BUILDERS, build_runner
        >>> # Define a new RunnerConstructor
        >>> @RUNNER_BUILDERS.register_module()
        >>> class MyRunnerConstructor:
        ...     def __init__(self, runner_cfg, default_args=None):
        ...         if not isinstance(runner_cfg, dict):
        ...             raise TypeError('runner_cfg should be a dict',
        ...                             f'but got {type(runner_cfg)}')
        ...         self.runner_cfg = runner_cfg
        ...         self.default_args = default_args
        ...
        ...     def __call__(self):
        ...         runner = RUNNERS.build(self.runner_cfg,
        ...                                default_args=self.default_args)
        ...         # Add new properties for existing runner
        ...         runner.my_name = 'my_runner'
        ...         runner.my_function = lambda self: print(self.my_name)
        ...         ...
        >>> # build your runner
        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
        ...                   constructor='MyRunnerConstructor')
        >>> runner = build_runner(runner_cfg)
    """

    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
        if not isinstance(runner_cfg, dict):
            raise TypeError('runner_cfg should be a dict',
                            f'but got {type(runner_cfg)}')
        self.runner_cfg = runner_cfg
        self.default_args = default_args

    def __call__(self):
        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)

1、Case study (layer_decay_optimizer_constructor.py)


1、Write the optimizer builder functions in builder.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
from mmcv.utils import Registry, build_from_cfg

OPTIMIZER_BUILDERS = Registry(
    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)


def build_optimizer_constructor(cfg):
    constructor_type = cfg.get('type')
    if constructor_type in OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
    else:
        raise KeyError(f'{constructor_type} is not registered '
                       'in the optimizer builder registry.')


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer
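
To make the config handling in `build_optimizer` concrete, here is a small sketch of just the splitting step: `constructor` and `paramwise_cfg` are popped from a copy of the config, and everything left over becomes the optimizer's own arguments. `split_optimizer_cfg` is a hypothetical helper written for this illustration; it does not build anything.

```python
import copy


def split_optimizer_cfg(cfg):
    """Return the dict that would be passed to the constructor registry."""
    optimizer_cfg = copy.deepcopy(cfg)  # the caller's config is left untouched
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    return dict(
        type=constructor_type,
        optimizer_cfg=optimizer_cfg,
        paramwise_cfg=paramwise_cfg)


cfg = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    type='AdamW',
    lr=0.0001,
    paramwise_cfg={'decay_rate': 0.95})
print(split_optimizer_cfg(cfg))
print(split_optimizer_cfg(dict(type='SGD', lr=0.02)))
```

When no `constructor` key is given, the default constructor name is used and `paramwise_cfg` is `None`.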
2、Write the optimizer constructor in layer_decay_optimizer_constructor.py
# Copyright (c) OpenMMLab. All rights reserved.
import json
# 1. First, import the DefaultOptimizerConstructor class
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from mmdet.utils import get_root_logger
from .builder import OPTIMIZER_BUILDERS


def get_layer_id_for_convnext(var_name, max_layer_id):
    """Get the layer id to set the different learning rates in ``layer_wise``
    decay_type.
    Args:
        var_name (str): The key of the model.
        max_layer_id (int): Maximum layer id.
    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.downsample_layers'):
        stage_id = int(var_name.split('.')[2])
        if stage_id == 0:
            layer_id = 0
        elif stage_id == 1:
            layer_id = 2
        elif stage_id == 2:
            layer_id = 3
        elif stage_id == 3:
            layer_id = max_layer_id
        return layer_id
    elif var_name.startswith('backbone.stages'):
        stage_id = int(var_name.split('.')[2])
        block_id = int(var_name.split('.')[3])
        if stage_id == 0:
            layer_id = 1
        elif stage_id == 1:
            layer_id = 2
        elif stage_id == 2:
            layer_id = 3 + block_id // 3
        elif stage_id == 3:
            layer_id = max_layer_id
        return layer_id
    else:
        return max_layer_id + 1


def get_stage_id_for_convnext(var_name, max_stage_id):
    """Get the stage id to set the different learning rates in ``stage_wise``
    decay_type.
    Args:
        var_name (str): The key of the model.
        max_stage_id (int): Maximum stage id.
    Returns:
        int: The id number corresponding to different learning rate in
        ``LearningRateDecayOptimizerConstructor``.
    """

    if var_name in ('backbone.cls_token', 'backbone.mask_token',
                    'backbone.pos_embed'):
        return 0
    elif var_name.startswith('backbone.downsample_layers'):
        return 0
    elif var_name.startswith('backbone.stages'):
        stage_id = int(var_name.split('.')[2])
        return stage_id + 1
    else:
        return max_stage_id - 1

# 2. Define the LearningRateDecayOptimizerConstructor class
@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    # Different learning rates are set for different layers of backbone.
    # Note: Currently, this optimizer constructor is built for ConvNeXt.

    def add_params(self, params, module, **kwargs):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
        """
        logger = get_root_logger()

        parameter_groups = {}
        logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        logger.info('Build LearningRateDecayOptimizerConstructor  '
                    f'{decay_type} {decay_rate} - {num_layers}')
        weight_decay = self.base_wd
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token'):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay
            if 'layer_wise' in decay_type:
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_layer_id_for_convnext(
                        name, self.paramwise_cfg.get('num_layers'))
                    logger.info(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            elif decay_type == 'stage_wise':
                if 'ConvNeXt' in module.backbone.__class__.__name__:
                    layer_id = get_stage_id_for_convnext(name, num_layers)
                    logger.info(f'set param {name} as id {layer_id}')
                else:
                    raise NotImplementedError()
            group_name = f'layer_{layer_id}_{group_name}'

            if group_name not in parameter_groups:
                scale = decay_rate**(num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
        params.extend(parameter_groups.values())
3、Import the constructor in __init__.py
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import OPTIMIZER_BUILDERS, build_optimizer
from .layer_decay_optimizer_constructor import \
    LearningRateDecayOptimizerConstructor

__all__ = [
    'LearningRateDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS',
    'build_optimizer'
]
4、Use it in the config file
optimizer = dict(
    _delete_=True,
    constructor='LearningRateDecayOptimizerConstructor',
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg={
        'decay_rate': 0.95,
        'decay_type': 'layer_wise',
        'num_layers': 6
    })
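
As a worked example of what this config produces, the sketch below reproduces the scale formula from `add_params` above, `decay_rate ** (num_layers - layer_id - 1)`, where the constructor first adds 2 to the configured `num_layers`. `layer_lr_scales` is a hypothetical helper for illustration; the values `6`, `0.95`, and `0.0001` are taken from the config.

```python
def layer_lr_scales(num_layers_cfg, decay_rate):
    """Map each layer id to its learning-rate scale."""
    num_layers = num_layers_cfg + 2  # same offset as in add_params
    return {
        layer_id: decay_rate ** (num_layers - layer_id - 1)
        for layer_id in range(num_layers)
    }


base_lr = 0.0001
scales = layer_lr_scales(num_layers_cfg=6, decay_rate=0.95)
for layer_id, scale in scales.items():
    print(f'layer {layer_id}: lr_scale={scale:.4f}, lr={base_lr * scale:.6f}')
```

The highest id (7 here, which `get_layer_id_for_convnext` assigns to everything outside the backbone) keeps the full base learning rate, and each step toward the input multiplies the rate by another factor of 0.95.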

3、Additional settings

1、 Use gradient clipping to stabilize training: some models need gradient clipping to keep the training process stable.
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

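If your config inherits a base config that already sets optimizer_config, you may need _delete_=True to override the unnecessary settings. See the config documentation for more details.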
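
To show what `max_norm=35, norm_type=2` means numerically, here is a pure-Python sketch of norm-based gradient clipping on a flat list of floats. It follows the rule used by `torch.nn.utils.clip_grad_norm_` (scale all gradients by `max_norm / (total_norm + 1e-6)`, capped at 1), to which, as far as I understand, MMCV forwards the `grad_clip` arguments. `clip_by_norm` is a hypothetical helper, not part of any library.

```python
def clip_by_norm(grads, max_norm, norm_type=2.0):
    """Rescale grads so that their p-norm does not exceed max_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    # The coefficient is capped at 1 so small gradients are left unchanged.
    clip_coef = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * clip_coef for g in grads], total_norm


clipped, norm_before = clip_by_norm([30.0, 40.0], max_norm=35)
print(norm_before, clipped)

unchanged, _ = clip_by_norm([3.0, 4.0], max_norm=35)
print(unchanged)
```

Clipping only rescales the gradient vector; its direction is preserved.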

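2、 Use a momentum schedule to accelerate model convergence:

We support a momentum scheduler that modifies the model's momentum according to the learning rate, which can make the model converge faster. A momentum scheduler is usually used together with an LR scheduler; for example, the following config is used in 3D detection to accelerate convergence. For more details, please refer to the implementations of CyclicLrUpdater and CyclicMomentumUpdater.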

lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 1e-4),
    cyclic_times=1,
    step_ratio_up=0.4,
)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.85 / 0.95, 1),
    cyclic_times=1,
    step_ratio_up=0.4,
)
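
The shape of this cyclic schedule can be sketched as follows. This is a simplified illustration, not the MMCV implementation: it assumes each cycle has an "up" phase covering `step_ratio_up` of the iterations (ratio goes from 1 to `target_ratio[0]`) and a "down" phase (from `target_ratio[0]` to `target_ratio[1]`), with cosine interpolation inside each phase. `cyclic_ratio`, `max_iters=1000`, and the base values `0.001` and `0.95` are assumptions made for the demo.

```python
import math


def cyclic_ratio(cur_iter, max_iters, target_ratio,
                 cyclic_times=1, step_ratio_up=0.4):
    """Multiplier applied to the base value (lr or momentum) at cur_iter."""
    iters_per_cycle = max_iters // cyclic_times
    iter_up = int(step_ratio_up * iters_per_cycle)
    pos = cur_iter % iters_per_cycle
    if pos < iter_up:
        start, end = 1.0, target_ratio[0]
        factor = pos / iter_up
    else:
        start, end = target_ratio[0], target_ratio[1]
        factor = (pos - iter_up) / (iters_per_cycle - iter_up)
    # Cosine interpolation from start (factor=0) to end (factor=1).
    return end + 0.5 * (start - end) * (math.cos(math.pi * factor) + 1)


base_lr, base_momentum, max_iters = 0.001, 0.95, 1000  # assumed values
for it in (0, 200, 400, 700, 999):
    lr = base_lr * cyclic_ratio(it, max_iters, (10, 1e-4))
    mom = base_momentum * cyclic_ratio(it, max_iters, (0.85 / 0.95, 1))
    print(f'iter {it}: lr={lr:.6f}, momentum={mom:.4f}')
```

Under these assumptions the two schedules move in opposite directions: momentum drops toward 0.85 while the learning rate climbs to its peak, then returns to 0.95 as the learning rate decays.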

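3、 Customize the training learning rate schedule

By default, we use a step learning rate with the 1x schedule, which calls StepLRHook in MMCV. Many other learning rate schedules are supported, such as CosineAnnealing and Poly. Here are some examples.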
Poly schedule:

lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
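
For intuition, the polynomial decay behind this config can be written as `lr = (base_lr - min_lr) * (1 - progress / max_progress) ** power + min_lr`. The sketch below evaluates that formula; `poly_lr`, `base_lr=0.02`, and `max_progress=1000` are assumptions for the demo, and with `by_epoch=False` the progress is counted in iterations.

```python
def poly_lr(base_lr, progress, max_progress, power=0.9, min_lr=1e-4):
    """Polynomial decay from base_lr down to min_lr."""
    coeff = (1 - progress / max_progress) ** power
    return (base_lr - min_lr) * coeff + min_lr


for it in (0, 250, 500, 750, 1000):
    print(it, round(poly_lr(0.02, it, 1000), 6))
```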

CosineAnnealing schedule:

lr_config = dict(
    policy='CosineAnnealing',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 10,
    min_lr_ratio=1e-5)
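
The two parts of this config, cosine annealing and linear warmup, can be sketched separately. This is an illustration under stated assumptions rather than the MMCV code: the learning rate anneals from `base_lr` to `base_lr * min_lr_ratio` along a half cosine, and during the first `warmup_iters` iterations it is scaled up linearly starting from `warmup_ratio` times the regular value. `cosine_lr`, `linear_warmup`, `base_lr=0.02`, and `max_iters=10000` are hypothetical.

```python
import math


def cosine_lr(base_lr, progress, max_progress, min_lr_ratio=1e-5):
    """Half-cosine decay from base_lr to base_lr * min_lr_ratio."""
    target_lr = base_lr * min_lr_ratio
    cos_out = math.cos(math.pi * progress / max_progress) + 1
    return target_lr + 0.5 * (base_lr - target_lr) * cos_out


def linear_warmup(regular_lr, cur_iter, warmup_iters=1000, warmup_ratio=0.1):
    """Scale regular_lr linearly from warmup_ratio up to 1 during warmup."""
    if cur_iter >= warmup_iters:
        return regular_lr
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return regular_lr * (1 - k)


base_lr, max_iters = 0.02, 10000  # assumed values
for it in (0, 500, 1000, 5000, 10000):
    lr = linear_warmup(cosine_lr(base_lr, it, max_iters), it)
    print(it, f'{lr:.8f}')
```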

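4、 Customize the workflow

A workflow is a list of (phase, epochs) pairs that specifies the running order and the number of epochs for each phase. By default it is set to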

workflow = [('train', 1)]

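which means running 1 epoch of training. Sometimes users may want to check some metrics (e.g. loss, accuracy) of the model on the validation set. In that case, the workflow can be set to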

[('train', 1), ('val', 1)]

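so that 1 epoch of training and 1 epoch of validation are run in turn, repeatedly.
1、The parameters of the model will not be updated during a val epoch.
2、The keyword total_epochs in the config only controls the number of training epochs and does not affect the validation workflow.
3、The workflows [('train', 1), ('val', 1)] and [('train', 1)] do not change the behavior of EvalHook, because EvalHook is called by after_train_epoch, while the validation workflow only affects hooks that are called through after_val_epoch. Therefore, the only difference between [('train', 1), ('val', 1)] and [('train', 1)] is that the runner computes the losses on the validation set after each training epoch.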
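
The way a workflow is consumed can be sketched as a simple loop. This is a simplified model of an epoch-based runner, not the MMCV implementation: it only records which phase runs, assumes the workflow contains at least one 'train' phase, and counts only training epochs toward `max_epochs`. `run_workflow` is a hypothetical name.

```python
def run_workflow(workflow, max_epochs):
    """Return the sequence of phases executed for the given workflow."""
    executed = []
    epoch = 0  # only training epochs advance this counter
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                if mode == 'train':
                    if epoch >= max_epochs:
                        break  # training budget is used up
                    epoch += 1
                executed.append(mode)
    return executed


print(run_workflow([('train', 1)], max_epochs=3))
print(run_workflow([('train', 1), ('val', 1)], max_epochs=2))
print(run_workflow([('train', 2), ('val', 1)], max_epochs=3))
```

In this model, val phases never consume the epoch budget, which matches note 2 above.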
