mmaction2 Guide
7. Customize Runtime Settings
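Use optimizers supported by PyTorch
Set type to the name of any optimizer in torch.optim and pass its arguments directly, e.g. Adam:
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
Any argument of the PyTorch optimizer can be set this way. For Adam, the main arguments (shown here with PyTorch's default values) are:
optimizer = dict(type='Adam', lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)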
Customize the optimizer
Additional settings
Use gradient clipping to stabilize training
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
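In MMCV, grad_clip is forwarded to PyTorch's torch.nn.utils.clip_grad_norm_: when the total norm of all gradients exceeds max_norm, every gradient is multiplied by the same factor, so the total norm drops to about max_norm and the update direction is preserved. The following is a pure-Python sketch of that rule, assuming gradients are plain lists of floats; clip_grad_norm is a hypothetical helper, not an MMCV or PyTorch function.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, norm_type: float = 2.0) -> float:
    """Scale all gradients in place so their combined p-norm is at most max_norm.

    Hypothetical helper mirroring the rule of torch.nn.utils.clip_grad_norm_
    for a finite norm_type. Returns the total norm measured before clipping.
    """
    # Total p-norm over every element of every gradient.
    total_norm = sum(abs(g) ** norm_type for grad in grads for g in grad) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)  # small epsilon avoids division by zero
    if clip_coef < 1.0:
        # One shared factor for all parameters keeps the gradient direction unchanged.
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[30.0, 40.0], [0.0, 120.0]]  # total L2 norm is sqrt(900 + 1600 + 14400) = 130
before = clip_grad_norm(grads, max_norm=35, norm_type=2)
after = math.sqrt(sum(g * g for grad in grads for g in grad))
print(before, after)
```

With max_norm=35 the gradients above are scaled by roughly 35/130; gradients whose total norm is already below 35 are left untouched.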
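Use a momentum schedule to accelerate convergence
A momentum schedule changes the model's momentum along with the learning rate and is usually paired with an LR schedule, for example the cyclic policy (CyclicLrUpdater plus its momentum counterpart):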
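lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 1e-4),  # lr rises to 10x the base lr, then falls to 1e-4x the base lr
    cyclic_times=1,           # number of cycles over the whole training run
    step_ratio_up=0.4,        # the first 40% of each cycle is the rising phase
)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.85 / 0.95, 1),  # momentum drops to 0.85/0.95 of its base while lr rises, then returns to 1x
    cyclic_times=1,
    step_ratio_up=0.4,
)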
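To see what the two cyclic configs do together, here is a pure-Python sketch modelled loosely on MMCV's cyclic LR and momentum updaters. It assumes iteration-based updates, cosine interpolation within each phase, and no decay between cycles; cyclic_value and annealing_cos are hypothetical helpers written for this example.

```python
import math


def annealing_cos(start: float, end: float, factor: float) -> float:
    """Cosine interpolation: returns start at factor=0 and end at factor=1."""
    return end + 0.5 * (start - end) * (math.cos(math.pi * factor) + 1)


def cyclic_value(it, max_iters, base, target_ratio, cyclic_times=1, step_ratio_up=0.4):
    """Value (lr or momentum) at iteration `it` under a cyclic schedule.

    Each cycle has a first phase moving from 1x to target_ratio[0]x the base
    value over the first step_ratio_up of the cycle, then a second phase
    moving from target_ratio[0]x to target_ratio[1]x the base value.
    """
    iters_per_cycle = max_iters // cyclic_times
    iter_up = int(step_ratio_up * iters_per_cycle)
    phases = [
        (0, iter_up, 1.0, target_ratio[0]),
        (iter_up, iters_per_cycle, target_ratio[0], target_ratio[1]),
    ]
    curr = it % iters_per_cycle  # position inside the current cycle
    for start_it, end_it, start_ratio, end_ratio in phases:
        if start_it <= curr < end_it:
            progress = (curr - start_it) / (end_it - start_it)
            return annealing_cos(base * start_ratio, base * end_ratio, progress)
    raise ValueError(f'iteration {it} is outside the schedule')


base_lr, base_momentum, max_iters = 0.001, 0.95, 100
for it in (0, 20, 40, 70, 99):
    lr = cyclic_value(it, max_iters, base_lr, (10, 1e-4))
    momentum = cyclic_value(it, max_iters, base_momentum, (0.85 / 0.95, 1))
    print(f'iter {it:3d}: lr={lr:.7f} momentum={momentum:.4f}')
```

With base_lr=0.001 over 100 iterations, the lr peaks at 0.01 at iteration 40 while momentum bottoms out at 0.85; afterwards the lr decays towards 1e-7 as momentum returns to 0.95.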
Customize the training schedule
Config (cosine annealing with linear warmup):
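lr_config = dict(
    policy='CosineAnnealing',
    warmup='linear',        # ramp the lr up linearly at the start of training
    warmup_iters=1000,      # warmup length, in iterations
    warmup_ratio=1.0 / 10,  # warmup starts at 0.1x the regular lr
    min_lr_ratio=1e-5)      # cosine decays to 1e-5x the base lr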
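The schedule above can be sketched in pure Python as follows. For simplicity it assumes both the cosine decay and the warmup are counted in iterations (MMCV's CosineAnnealing updater steps the cosine part per epoch by default, with warmup counted in iterations); cosine_lr_with_warmup is a hypothetical helper. During warmup, the regular cosine lr is scaled by a factor that rises linearly from warmup_ratio to 1.

```python
import math


def cosine_lr_with_warmup(it, max_iters, base_lr, warmup_iters=1000,
                          warmup_ratio=1.0 / 10, min_lr_ratio=1e-5):
    """Learning rate at iteration `it` for cosine annealing with linear warmup."""
    # Regular schedule: cosine decay from base_lr down to base_lr * min_lr_ratio.
    target_lr = base_lr * min_lr_ratio
    progress = it / max_iters
    regular_lr = target_lr + 0.5 * (base_lr - target_lr) * (math.cos(math.pi * progress) + 1)
    if it < warmup_iters:
        # Linear warmup: the scale factor goes from warmup_ratio (it=0) up to 1 (it=warmup_iters).
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


for it in (0, 500, 1000, 5000, 10000):
    print(it, cosine_lr_with_warmup(it, max_iters=10000, base_lr=0.01))
```

For example, with base_lr=0.01 and 10000 iterations, training starts at 0.001, joins the regular cosine curve at iteration 1000, and ends at 1e-7.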
Customize the Workflow
To evaluate after each training epoch, either use EvalHook or add a val phase to the workflow.
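workflow is a list of (phase, epochs) pairs that specifies the running order and the number of epochs of each phase. By default it is workflow = [('train', 1)], i.e. one training epoch per round.
To train one epoch and then validate one epoch: workflow = [('train', 1), ('val', 1)]. Model parameters are not updated during val epochs.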
The keyword total_epochs in the config only controls the number of training epochs and does not affect the validation workflow.
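The interaction between workflow and total_epochs can be sketched with a simplified version of the loop in MMCV's EpochBasedRunner: only train phases advance the epoch counter, so validation phases never count towards total_epochs. run_workflow is a hypothetical helper that only records which phases would run.

```python
from typing import List, Tuple


def run_workflow(workflow: List[Tuple[str, int]], max_epochs: int) -> List[str]:
    """Return the phases an epoch-based runner would execute, in order."""
    if not any(mode == 'train' and epochs > 0 for mode, epochs in workflow):
        raise ValueError('workflow needs a train phase, otherwise the loop never ends')
    executed = []
    epoch = 0  # counts training epochs only
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                # Training stops once total_epochs is reached; val phases are not capped.
                if mode == 'train' and epoch >= max_epochs:
                    break
                executed.append(mode)
                if mode == 'train':
                    epoch += 1
    return executed


print(run_workflow([('train', 1), ('val', 1)], max_epochs=3))
```

For instance, workflow = [('train', 2), ('val', 1)] with total_epochs = 3 runs train, train, val, train, val: the second round stops after one training epoch, but its validation phase still runs.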
Customize Hooks
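Hooks are registered with the runner and triggered at fixed points of training (before/after the run, each epoch, each iteration). Each hook has a priority: hooks with higher priority are triggered first, and hooks with the same priority are triggered in the order in which they are registered.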
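The ordering rule can be sketched as a sorted insert: a hook with a smaller priority value runs earlier, and a new hook is placed after every hook whose value is less than or equal to its own, which keeps equal-priority hooks in registration order. The numeric levels below are assumed to match MMCV's Priority enum (only their relative order matters here); register_hook is a simplified stand-in for the runner method of the same name, and EarlyHook is a hypothetical hook name.

```python
# Priority levels, assumed to match mmcv.runner.Priority (smaller value = triggered earlier).
PRIORITY = {
    'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'ABOVE_NORMAL': 40, 'NORMAL': 50,
    'BELOW_NORMAL': 60, 'LOW': 70, 'VERY_LOW': 90, 'LOWEST': 100,
}


def register_hook(hooks, name, priority='NORMAL'):
    """Insert (name, value) into hooks, keeping the list ordered by priority value."""
    value = PRIORITY[priority]
    # Walk backwards and insert after the last hook whose value is <= the new one,
    # so hooks with equal priority stay in registration order.
    for i in range(len(hooks) - 1, -1, -1):
        if value >= hooks[i][1]:
            hooks.insert(i + 1, (name, value))
            return
    hooks.insert(0, (name, value))  # higher priority than every existing hook


hooks = []
register_hook(hooks, 'TextLoggerHook', 'VERY_LOW')
register_hook(hooks, 'CheckpointHook', 'NORMAL')
register_hook(hooks, 'MyHook', 'NORMAL')
register_hook(hooks, 'EarlyHook', 'HIGH')
print([name for name, _ in hooks])
```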
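1. Implement a new hook
Subclass Hook and override the stage methods (before_run, after_run, ...) with whatever the hook needs to do at that point; stages you do not override do nothing.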
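from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()  # register MyHook so configs can refer to it as type='MyHook'
class MyHook(Hook):

    def __init__(self, a, b):
        pass  # store the hook's arguments here

    def before_run(self, runner):
        pass  # called once, before training starts

    def after_run(self, runner):
        pass  # called once, after training ends

    def before_epoch(self, runner):
        pass  # called at the start of every (train or val) epoch

    def after_epoch(self, runner):
        pass  # called at the end of every (train or val) epoch

    def before_iter(self, runner):
        pass  # called before every iteration

    def after_iter(self, runner):
        pass  # called after every iteration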
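Which method fires when is decided by the runner. The following toy runner is a hypothetical, stdlib-only stand-in for MMCV's runner (not its actual code): it calls the named stage method on each registered hook in list order, the same dispatch idea MMCV uses, which shows when an overridden before_epoch or after_iter would run. Hook here is a minimal local stand-in, and RecordingHook is a hypothetical hook that just records calls.

```python
class Hook:
    """Minimal stand-in for a hook base class: every stage is a no-op by default."""

    def before_run(self, runner): pass
    def after_run(self, runner): pass
    def before_epoch(self, runner): pass
    def after_epoch(self, runner): pass
    def before_iter(self, runner): pass
    def after_iter(self, runner): pass


class RecordingHook(Hook):
    """Hypothetical hook that overrides two stages and records when they fire."""

    def __init__(self):
        self.calls = []

    def before_epoch(self, runner):
        self.calls.append(f'before_epoch {runner.epoch}')

    def after_iter(self, runner):
        self.calls.append(f'after_iter {runner.iter}')


class ToyRunner:
    """Hypothetical runner showing where each hook stage is triggered."""

    def __init__(self, hooks, max_epochs, iters_per_epoch):
        self.hooks = hooks
        self.max_epochs = max_epochs
        self.iters_per_epoch = iters_per_epoch
        self.epoch = 0
        self.iter = 0

    def call_hook(self, fn_name):
        # Call the named stage method on every hook, in list (priority) order.
        for hook in self.hooks:
            getattr(hook, fn_name)(self)

    def run(self):
        self.call_hook('before_run')
        while self.epoch < self.max_epochs:
            self.call_hook('before_epoch')
            for _ in range(self.iters_per_epoch):
                self.call_hook('before_iter')
                # ... forward, backward and optimizer step would happen here ...
                self.call_hook('after_iter')
                self.iter += 1
            self.call_hook('after_epoch')
            self.epoch += 1
        self.call_hook('after_run')


hook = RecordingHook()
ToyRunner([hook], max_epochs=2, iters_per_epoch=2).run()
print(hook.calls)
```

Running two epochs of two iterations each records before_epoch 0, after_iter 0, after_iter 1, before_epoch 1, after_iter 2, after_iter 3.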
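2. Register the hook
Assuming the hook lives in mmaction/core/utils/my_hook.py, make sure that module gets imported, in one of two ways:
- Modify mmaction/core/utils/__init__.py and add: from .my_hook import MyHook
- Or import the user-defined hook through custom_imports in the config:
custom_imports = dict(imports=['mmaction.core.utils.my_hook'], allow_failed_imports=False)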
3. Modify the config
custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL')  # priority is optional; a_value/b_value are placeholders
]
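Both the registration step and the type='MyHook' entry rely on a registry: the @HOOKS.register_module() decorator records the class when its module is imported (which is why the module must be imported in step 2), and at training time the type string is looked up and the class is called with the remaining keys. A minimal stdlib sketch of that mechanism follows. This Registry class is hypothetical and much simpler than MMCV's Registry and build_from_cfg; it also assumes, as in the OpenMMLab training scripts, that priority is taken out of each entry before the constructor is called. The values a=1, b=2 are illustrative only.

```python
class Registry:
    """Hypothetical minimal registry: maps class names to classes."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def _register(cls):
            if cls.__name__ in self._modules:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._modules[cls.__name__] = cls
            return cls  # the decorator returns the class unchanged
        return _register

    def build(self, cfg):
        args = dict(cfg)  # copy so the caller's config is not mutated
        obj_type = args.pop('type')
        if obj_type not in self._modules:
            raise KeyError(f'{obj_type} is not in the {self.name} registry')
        return self._modules[obj_type](**args)


HOOKS = Registry('hook')


@HOOKS.register_module()  # runs at import time, adding MyHook to the registry
class MyHook:
    """Hypothetical hook with two constructor arguments, as in step 1."""

    def __init__(self, a, b):
        self.a = a
        self.b = b


custom_hooks = [dict(type='MyHook', a=1, b=2, priority='NORMAL')]
built = []
for hook_cfg in custom_hooks:
    cfg = dict(hook_cfg)                      # copy so the config is not mutated
    priority = cfg.pop('priority', 'NORMAL')  # priority is for the runner, not the constructor
    built.append((HOOKS.build(cfg), priority))
print(built)
```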
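Use hooks implemented in MMCV
If the hook is already implemented in MMCV, reference it directly in the config:
mmcv_hooks = [
    dict(type='MMCVHook', a=a_value, b=b_value, priority='NORMAL')  # MMCVHook stands for the name of an existing MMCV hook
]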
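The following hooks are registered by default rather than through custom_hooks; they are configured through these config fields:
- log_config
- checkpoint_config
- evaluation
- lr_config
- optimizer_config
- momentum_config
Of these, only the logger hook has VERY_LOW priority; all others are NORMAL.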
Modify the checkpoint hook
checkpoint_config = dict(interval=1)  # save a checkpoint every epoch
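Modify the log hook
See mmcv.runner.LoggerHook (https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.LoggerHook).
log_config = dict(
    interval=50,  # log every 50 iterations
    hooks=[
        dict(type='TextLoggerHook'),        # text logs (terminal and log file)
        dict(type='TensorboardLoggerHook')  # additionally write logs for TensorBoard
    ])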
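interval=50 means each logger hook listed in hooks writes a record every 50 iterations. Assuming, as in MMCV's LoggerHook, that the reported values are averaged over those iterations from the runner's log buffer, the behaviour can be sketched as follows. LogBuffer here is a hypothetical stand-in, and log_every is a hypothetical helper that returns the lines (in a made-up format) instead of writing them.

```python
from typing import Dict, List


class LogBuffer:
    """Hypothetical simplified log buffer: keeps every logged value per key."""

    def __init__(self):
        self.history: Dict[str, List[float]] = {}

    def update(self, **values: float) -> None:
        for key, value in values.items():
            self.history.setdefault(key, []).append(value)

    def average(self, n: int) -> Dict[str, float]:
        """Mean of the latest n values of each key."""
        return {key: sum(vals[-n:]) / len(vals[-n:]) for key, vals in self.history.items()}


def log_every(losses: List[float], interval: int = 50) -> List[str]:
    """Return one log line per `interval` iterations, averaging the loss over that window."""
    buffer = LogBuffer()
    lines = []
    for inner_iter, loss in enumerate(losses):
        buffer.update(loss=loss)
        if (inner_iter + 1) % interval == 0:
            avg = buffer.average(interval)
            lines.append(f"iter [{inner_iter + 1}] loss: {avg['loss']:.4f}")
    return lines


print(log_every([float(i) for i in range(100)], interval=50))
```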
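Modify the evaluation hook
See eval_hooks (EvalHook). interval sets how often (in epochs) evaluation runs; other arguments such as metrics are passed on to dataset.evaluate().
evaluation = dict(interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy'])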
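For action recognition, a commonly used metric is top-k accuracy: a sample counts as correct if its ground-truth class is among the k highest-scoring classes. A pure-Python sketch follows; top_k_accuracy here is a hypothetical helper written for illustration, not mmaction2's implementation.

```python
from typing import List, Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int],
                   topk=(1, 5)) -> List[float]:
    """For each k, the fraction of samples whose label is among the k highest scores."""
    results = []
    for k in topk:
        hits = 0
        for sample_scores, label in zip(scores, labels):
            # Class indices sorted by descending score; keep the first k.
            top = sorted(range(len(sample_scores)), key=lambda c: sample_scores[c], reverse=True)[:k]
            hits += label in top
        results.append(hits / len(labels))
    return results


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 2]
print(top_k_accuracy(scores, labels, topk=(1, 2)))
```

Here the second sample's true class 1 is only the second-highest score, so it counts for top-2 but not for top-1 accuracy.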