Let's start from the main function in train.py. It breaks down into six parts:
●Parsing command-line arguments and the config file
●Setting up work_dir-related parameters
●Setting up the optimizer
●Setting the learning rate and resume-from-checkpoint behavior
●Building the model and datasets, and setting validation/checkpoint parameters
●Model training
1. Parsing command-line arguments and the config file
Command-line arguments are obtained via args = parse_args(). The arguments involved are:
Argument | Description |
---|---|
config | path to the config file |
--work-dir | output directory |
--amp | enable automatic mixed-precision training |
--auto-scale-lr | enable automatic learning-rate scaling |
--resume | checkpoint path for resuming training |
--ceph | use Ceph as the data storage backend |
--cfg-options | override settings in the config |
--launcher | distributed-training launcher |
--local_rank | local process rank |
The config file is imported via cfg = Config.fromfile(args.config). The config files for the models currently implemented in mmdetection3d all live in the configs folder.
~/mmdetection3d/configs$ tree -L 1
.
├── 3dssd
├── _base_
├── benchmark
├── centerpoint
├── cylinder3d
├── dfm
├── dgcnn
├── dynamic_voxelization
├── fcaf3d
├── fcos3d
├── free_anchor
├── groupfree3d
├── h3dnet
├── imvotenet
├── imvoxelnet
├── minkunet
├── monoflex
├── mvxnet
├── nuimages
├── paconv
├── parta2
├── pgd
├── pointnet2
├── pointpillars
├── point_rcnn
├── pv_rcnn
├── regnet
├── sassd
├── second
├── smoke
├── spvcnn
├── ssn
└── votenet
2. Setting up work_dir-related parameters
If work_dir was given on the command line, cfg.work_dir = args.work_dir; if work_dir is None, ./work_dirs/<config filename> is used as the default working directory.
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
    # update configs according to CLI args if args.work_dir is not None
    cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
    # use config filename as default work_dir if cfg.work_dir is None
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])
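The priority chain above is simple enough to capture in a tiny stdlib-only helper (the function name is mine, not something that exists in mmdetection3d), which makes the fallback order explicit:

```python
import os.path as osp

def resolve_work_dir(cli_work_dir, cfg_work_dir, config_path):
    """Mimic train.py's priority: CLI > cfg.work_dir > ./work_dirs/<config stem>."""
    if cli_work_dir is not None:
        return cli_work_dir
    if cfg_work_dir is not None:
        return cfg_work_dir
    # fall back to the config file's basename without its extension
    stem = osp.splitext(osp.basename(config_path))[0]
    return osp.join('./work_dirs', stem)
```

For configs/second/second.py with neither a CLI value nor a config value, this yields ./work_dirs/second.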
3. Setting up the optimizer (mixed precision)
This involves the --amp command-line flag and optim_wrapper.type in the config file.
# enable automatic-mixed-precision training
if args.amp is True:
    optim_wrapper = cfg.optim_wrapper.type
    if optim_wrapper == 'AmpOptimWrapper':
        print_log(
            'AMP training is already enabled in your config.',
            logger='current',
            level=logging.WARNING)
    else:
        assert optim_wrapper == 'OptimWrapper', (
            '`--amp` is only supported when the optimizer wrapper type is '
            f'`OptimWrapper` but got {optim_wrapper}.')
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.loss_scale = 'dynamic'
As an example, here is the optim_wrapper in configs/_base_/schedules/cyclic-20e.py:
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01),
    clip_grad=dict(max_norm=35, norm_type=2))
4. Setting the learning rate and resuming from a checkpoint
This involves the --auto-scale-lr command-line flag and the auto_scale_lr config entry. Take the auto_scale_lr in configs/_base_/schedules/cyclic-20e.py as an example:
auto_scale_lr = dict(enable=False, base_batch_size=32)
If --auto-scale-lr is passed, enable is set to True.
# set up learning-rate scaling
if args.auto_scale_lr:
    if 'auto_scale_lr' in cfg and \
            'enable' in cfg.auto_scale_lr and \
            'base_batch_size' in cfg.auto_scale_lr:
        cfg.auto_scale_lr.enable = True
    else:
        raise RuntimeError('Can not find "auto_scale_lr" or '
                           '"auto_scale_lr.enable" or '
                           '"auto_scale_lr.base_batch_size" in your'
                           ' configuration file.')
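What enable=True ultimately triggers (inside Runner.scale_lr, called later during train()) is the linear scaling rule: the learning rate is multiplied by real_batch_size / base_batch_size. A minimal sketch of that arithmetic (the function name is mine):

```python
def auto_scale_lr(base_lr, base_batch_size, num_gpus, samples_per_gpu):
    """Linear scaling rule: lr grows proportionally with the total batch size."""
    real_batch_size = num_gpus * samples_per_gpu
    return base_lr * real_batch_size / base_batch_size
```

With base_batch_size=32 as in cyclic-20e.py, training on 8 GPUs with 4 samples each leaves the LR unchanged, while 4 GPUs with 4 samples each halves it.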
For resuming: if a checkpoint path is given, training resumes from it; if none is given, the runner tries to resume automatically from the latest checkpoint.
# set up resuming from a checkpoint
if args.resume == 'auto':
    cfg.resume = True
    cfg.load_from = None
elif args.resume is not None:
    cfg.resume = True
    cfg.load_from = args.resume
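The mapping between the --resume value and the (cfg.resume, cfg.load_from) pair that Runner.load_or_resume later consumes can be summarized in a small helper (hypothetical, for illustration only):

```python
def resolve_resume(resume_arg):
    """Map the --resume CLI value to (cfg.resume, cfg.load_from).

    'auto'  -> resume, letting the runner search work_dir for the latest checkpoint
    <path>  -> resume from that specific checkpoint
    None    -> train from scratch (cfg left at its defaults)
    """
    if resume_arg == 'auto':
        return True, None
    if resume_arg is not None:
        return True, resume_arg
    return False, None
```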
5. Building the network model
This part is fairly involved and relies on mmengine (for an introduction, see the blog post MMEngine理解). Depending on whether 'runner_type' is present in the config, the runner object is created either with from_cfg() from mmengine.runner or with build() from mmengine.registry.
# build the runner from config
if 'runner_type' not in cfg:
    # build the default runner
    runner = Runner.from_cfg(cfg)
else:
    # build customized runner from the registry
    # if 'runner_type' is set in the cfg
    runner = RUNNERS.build(cfg)
5.1 mmengine.runner.from_cfg()
Before stepping into from_cfg, we first need to understand what actually happens when train.py imports the Runner class:
from mmengine.runner import Runner
The Runner class is decorated with @RUNNERS.register_module(), which means the class itself is passed as an argument into RUNNERS.register_module(Runner):
@RUNNERS.register_module()
class Runner:
This in turn requires a closer look at where RUNNERS is imported from and how it is created.
Import:
from mmengine.registry import RUNNERS
Creation: line 14 of mmengine/registry/root.py shows that RUNNERS is a Registry object.
RUNNERS = Registry('runner', build_func=build_runner_from_cfg)
It is created through Registry's constructor; skimming __init__() shows that it just initializes a few attributes.
def __init__(self,
             name: str,
             build_func: Optional[Callable] = None,
             parent: Optional['Registry'] = None,
             scope: Optional[str] = None,
             locations: List = []):
    from .build_functions import build_from_cfg
    self._name = name
    self._module_dict: Dict[str, Type] = dict()
    self._children: Dict[str, 'Registry'] = dict()
    self._locations = locations
    self._imported = False

    if scope is not None:
        assert isinstance(scope, str)
        self._scope = scope
    else:
        self._scope = self.infer_scope()

    # See https://mypy.readthedocs.io/en/stable/common_issues.html#
    # variables-vs-type-aliases for the use
    self.parent: Optional['Registry']
    if parent is not None:
        assert isinstance(parent, Registry)
        parent._add_child(self)
        self.parent = parent
    else:
        self.parent = None

    # self.build_func will be set with the following priority:
    # 1. build_func
    # 2. parent.build_func
    # 3. build_from_cfg
    self.build_func: Callable
    if build_func is None:
        if self.parent is not None:
            self.build_func = self.parent.build_func
        else:
            self.build_func = build_from_cfg
    else:
        self.build_func = build_func
Back to RUNNERS.register_module(Runner): stepping into register_module(), the docstring tells us that name defaults to None, force to False, and module to None. In the decorator case module is None, so the inner _register closure is returned, and applying it to Runner calls self._register_module().
def register_module(
        self,
        name: Optional[Union[str, List[str]]] = None,
        force: bool = False,
        module: Optional[Type] = None) -> Union[type, Callable]:
    if not isinstance(force, bool):
        raise TypeError(f'force must be a boolean, but got {type(force)}')

    # raise the error ahead of time
    if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
        raise TypeError(
            'name must be None, an instance of str, or a sequence of str, '
            f'but got {type(name)}')

    # use it as a normal method: x.register_module(module=SomeClass)
    if module is not None:
        self._register_module(module=module, module_name=name, force=force)
        return module

    # use it as a decorator: @x.register_module()
    def _register(module):
        self._register_module(module=module, module_name=name, force=force)
        return module

    return _register
Stepping into _register_module(), it essentially fills RUNNERS' _module_dict, i.e. _module_dict = {'Runner': Runner}.
def _register_module(self,
                     module: Type,
                     module_name: Optional[Union[str, List[str]]] = None,
                     force: bool = False) -> None:
    if not callable(module):
        raise TypeError(f'module must be Callable, but got {type(module)}')

    if module_name is None:
        module_name = module.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in self._module_dict:
            existed_module = self.module_dict[name]
            raise KeyError(f'{name} is already registered in {self.name} '
                           f'at {existed_module.__module__}')
        self._module_dict[name] = module
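The whole registration path is easy to reproduce with a toy registry. The following stripped-down version (my own sketch, not mmengine code; it skips force, name lists, and error checks) shows how the decorator form ends up filling _module_dict at import time:

```python
class ToyRegistry:
    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def register_module(self, name=None, module=None):
        # Used as a plain call: reg.register_module(module=SomeClass)
        if module is not None:
            self._module_dict[name or module.__name__] = module
            return module

        # Used as a decorator: @reg.register_module()
        def _register(mod):
            self._module_dict[name or mod.__name__] = mod
            return mod
        return _register

RUNNERS = ToyRegistry('runner')

@RUNNERS.register_module()
class Runner:
    pass

# At this point the class is indexed by its own name:
# RUNNERS._module_dict == {'Runner': Runner}
```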
from_cfg() is implemented at line 451 of mmengine/runner/runner.py.
@classmethod
def from_cfg(cls, cfg: ConfigType) -> 'Runner':
    """Build a runner from config.

    Args:
        cfg (ConfigType): A config used for building runner. Keys of
            ``cfg`` can see :meth:`__init__`.

    Returns:
        Runner: A runner build from ``cfg``.
    """
    cfg = copy.deepcopy(cfg)
    runner = cls(
        model=cfg['model'],
        work_dir=cfg['work_dir'],
        train_dataloader=cfg.get('train_dataloader'),
        val_dataloader=cfg.get('val_dataloader'),
        test_dataloader=cfg.get('test_dataloader'),
        train_cfg=cfg.get('train_cfg'),
        val_cfg=cfg.get('val_cfg'),
        test_cfg=cfg.get('test_cfg'),
        auto_scale_lr=cfg.get('auto_scale_lr'),
        optim_wrapper=cfg.get('optim_wrapper'),
        param_scheduler=cfg.get('param_scheduler'),
        val_evaluator=cfg.get('val_evaluator'),
        test_evaluator=cfg.get('test_evaluator'),
        default_hooks=cfg.get('default_hooks'),
        custom_hooks=cfg.get('custom_hooks'),
        data_preprocessor=cfg.get('data_preprocessor'),
        load_from=cfg.get('load_from'),
        resume=cfg.get('resume', False),
        launcher=cfg.get('launcher', 'none'),
        env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))),
        log_processor=cfg.get('log_processor'),
        log_level=cfg.get('log_level', 'INFO'),
        visualizer=cfg.get('visualizer'),
        default_scope=cfg.get('default_scope', 'mmengine'),
        randomness=cfg.get('randomness', dict(seed=None)),
        experiment_name=cfg.get('experiment_name'),
        cfg=cfg,
    )
    return runner
Because of the @classmethod decorator, runner = cls(...) dispatches to Runner's __init__, which mostly just populates the runner's attributes; the most important part, model construction, happens at line 429 of runner.py.
def __init__(
    self,
    model: Union[nn.Module, Dict],
    work_dir: str,
    train_dataloader: Optional[Union[DataLoader, Dict]] = None,
    val_dataloader: Optional[Union[DataLoader, Dict]] = None,
    test_dataloader: Optional[Union[DataLoader, Dict]] = None,
    train_cfg: Optional[Dict] = None,
    val_cfg: Optional[Dict] = None,
    test_cfg: Optional[Dict] = None,
    auto_scale_lr: Optional[Dict] = None,
    optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
    param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
    val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
    test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
    default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
    custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
    data_preprocessor: Union[nn.Module, Dict, None] = None,
    load_from: Optional[str] = None,
    resume: bool = False,
    launcher: str = 'none',
    env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
    log_processor: Optional[Dict] = None,
    log_level: str = 'INFO',
    visualizer: Optional[Union[Visualizer, Dict]] = None,
    default_scope: str = 'mmengine',
    randomness: Dict = dict(seed=None),
    experiment_name: Optional[str] = None,
    cfg: Optional[ConfigType] = None,
):
    self._work_dir = osp.abspath(work_dir)
    mmengine.mkdir_or_exist(self._work_dir)

    # recursively copy the `cfg` because `self.cfg` will be modified
    # everywhere.
    if cfg is not None:
        if isinstance(cfg, Config):
            self.cfg = copy.deepcopy(cfg)
        elif isinstance(cfg, dict):
            self.cfg = Config(cfg)
    else:
        self.cfg = Config(dict())

    self._train_dataloader = train_dataloader
    self._train_loop = train_cfg

    self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
    self.optim_wrapper = optim_wrapper

    self.auto_scale_lr = auto_scale_lr

    self._check_scheduler_cfg(param_scheduler)
    self.param_schedulers = param_scheduler

    self._val_dataloader = val_dataloader
    self._val_loop = val_cfg
    self._val_evaluator = val_evaluator

    self._test_dataloader = test_dataloader
    self._test_loop = test_cfg
    self._test_evaluator = test_evaluator

    self._launcher = launcher
    if self._launcher == 'none':
        self._distributed = False
    else:
        self._distributed = True

    # self._timestamp will be set in the `setup_env` method. Besides,
    # it also will initialize multi-process and (or) distributed
    # environment.
    self.setup_env(env_cfg)
    # self._deterministic and self._seed will be set in the
    # `set_randomness`` method
    self._randomness_cfg = randomness
    self.set_randomness(**randomness)

    if experiment_name is not None:
        self._experiment_name = f'{experiment_name}_{self._timestamp}'
    elif self.cfg.filename is not None:
        filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
        self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
    else:
        self._experiment_name = self.timestamp
    self._log_dir = osp.join(self.work_dir, self.timestamp)
    mmengine.mkdir_or_exist(self._log_dir)

    self.default_scope = default_scope

    # Build log processor to format message.
    log_processor = dict() if log_processor is None else log_processor
    self.log_processor = self.build_log_processor(log_processor)
    # Since `get_instance` could return any subclass of ManagerMixin. The
    # corresponding attribute needs a type hint.
    self.logger = self.build_logger(log_level=log_level)

    # Collect and log environment information.
    self._log_env(env_cfg)

    # See `MessageHub` and `ManagerMixin` for more details.
    self.message_hub = self.build_message_hub()
    # visualizer used for writing log or visualizing all kinds of data
    self.visualizer = self.build_visualizer(visualizer)
    if self.cfg:
        self.visualizer.add_config(self.cfg)

    self._load_from = load_from
    self._resume = resume
    # flag to mark whether checkpoint has been loaded or resumed
    self._has_loaded = False

    # build a model
    self.model = self.build_model(model)
    # wrap model
    self.model = self.wrap_model(
        self.cfg.get('model_wrapper_cfg'), self.model)

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._hooks: List[Hook] = []
    # register hooks to `self._hooks`
    self.register_hooks(default_hooks, custom_hooks)
    # log hooks information
    self.logger.info(f'Hooks will be executed in the following '
                     f'order:\n{self.get_hooks_info()}')

    # dump `cfg` to `work_dir`
    self.dump_config()
Step into build_model(): it first checks whether the config's model is an nn.Module or a dict. An nn.Module is returned directly; a dict requires calling MODELS.build().
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
    if isinstance(model, nn.Module):
        return model
    elif isinstance(model, dict):
        model = MODELS.build(model)
        return model  # type: ignore
    else:
        raise TypeError('model should be a nn.Module object or dict, '
                        f'but got {model}')
Going straight into build(), it only delegates to build_func():
def build(self, cfg: dict, *args, **kwargs) -> Any:
    return self.build_func(cfg, *args, **kwargs, registry=self)
Here build_func() is the one specified when MODELS was registered, namely build_model_from_cfg(). Stepping into build_model_from_cfg() (mmengine/registry/build_functions.py, line 206), we can see its main work is delegated to build_from_cfg().
def build_model_from_cfg(
    cfg: Union[dict, ConfigDict, Config],
    registry: Registry,
    default_args: Optional[Union[dict, 'ConfigDict', 'Config']] = None
) -> 'nn.Module':
    from ..model import Sequential
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(_cfg, registry, default_args) for _cfg in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
What build_from_cfg() does is build the model from the config file, so we need a quick look at the config itself. Take centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py as an example: this config first executes the config files listed in _base_.
_base_ = [
'../_base_/datasets/nus-3d.py',
'../_base_/models/centerpoint_pillar02_second_secfpn_nus.py',
'../_base_/schedules/cyclic-20e.py', '../_base_/default_runtime.py'
]
......
With that background, step into build_from_cfg() (key parts excerpted below): it uses the type field of the model config dict to look up the corresponding class in the registry and instantiate it.
def build_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    # Avoid circular import
    from ..logging import print_log

    args = cfg.copy()
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry.')
    if inspect.isclass(obj_cls) and \
            issubclass(obj_cls, ManagerMixin):  # type: ignore
        obj = obj_cls.get_instance(**args)  # type: ignore
    else:
        obj = obj_cls(**args)  # type: ignore
    return obj
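The lookup-then-instantiate step can be demonstrated with plain Python. In this sketch (my own; the registry is just a dict and CenterPointStub is a made-up class), the type key selects the class and the remaining keys become constructor arguments:

```python
def toy_build_from_cfg(cfg, registry):
    """Sketch of build_from_cfg: pop 'type', look the class up, instantiate.

    `registry` here is just a plain dict mapping names to classes.
    """
    args = dict(cfg)                 # don't mutate the caller's config
    obj_type = args.pop('type')
    obj_cls = registry[obj_type] if isinstance(obj_type, str) else obj_type
    return obj_cls(**args)           # remaining keys become __init__ kwargs

class CenterPointStub:
    def __init__(self, voxel_size=None):
        self.voxel_size = voxel_size

model = toy_build_from_cfg(
    dict(type='CenterPointStub', voxel_size=(0.2, 0.2, 8)),
    {'CenterPointStub': CenterPointStub})
```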
Take the model in centerpoint_pillar02_second_secfpn_nus.py as an example: first CenterPoint is stored in obj_type and the type key is removed from args (obj_type = args.pop('type')); then the corresponding class is looked up in the registry (obj_cls = registry.get(obj_type)); finally it is instantiated (obj = obj_cls(**args)), i.e. CenterPoint's constructor __init__() runs with **args as its arguments.
model = dict(
    type=CenterPoint,
    data_preprocessor=dict(
        type=Det3DDataPreprocessor,
        voxel=True,
        voxel_layer=dict(
            max_num_points=20,
            voxel_size=voxel_size,
            max_voxels=(30000, 40000))),
    pts_voxel_encoder=dict(
        type=PillarFeatureNet,
        in_channels=5,
        feat_channels=[64],
        with_distance=False,
        voxel_size=(0.2, 0.2, 8),
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
        legacy=False),
    pts_middle_encoder=dict(
        type=PointPillarsScatter, in_channels=64, output_shape=(512, 512)),
    pts_backbone=dict(
        type=SECOND,
        in_channels=64,
        out_channels=[64, 128, 256],
        layer_nums=[3, 5, 5],
        layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type=Conv2d, bias=False)),
    pts_neck=dict(
        type=SECONDFPN,
        in_channels=[64, 128, 256],
        out_channels=[128, 128, 128],
        upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True),
    pts_bbox_head=dict(
        type=CenterHead,
        in_channels=sum([128, 128, 128]),
        tasks=[
            dict(num_class=1, class_names=['car']),
            dict(num_class=2, class_names=['truck', 'construction_vehicle']),
            dict(num_class=2, class_names=['bus', 'trailer']),
            dict(num_class=1, class_names=['barrier']),
            dict(num_class=2, class_names=['motorcycle', 'bicycle']),
            dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
        ],
        common_heads=dict(
            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
        share_conv_channel=64,
        bbox_coder=dict(
            type=CenterPointBBoxCoder,
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_num=500,
            score_threshold=0.1,
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            code_size=9),
        separate_head=dict(type=SeparateHead, init_bias=-2.19, final_kernel=3),
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(
            type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
        norm_bbox=True),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[512, 512, 1],
            voxel_size=voxel_size,
            out_size_factor=4,
            dense_reg=1,
            gaussian_overlap=0.1,
            max_objs=500,
            min_radius=2,
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])),
    test_cfg=dict(
        pts=dict(
            post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_per_img=500,
            max_pool_nms=False,
            min_radius=[4, 12, 10, 1, 0.85, 0.175],
            score_threshold=0.1,
            pc_range=[-51.2, -51.2],
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            nms_type='rotate',
            pre_max_size=1000,
            post_max_size=83,
            nms_thr=0.2)))
Continuing into CenterPoint's constructor __init__(), we find that it calls the constructor of its parent class MVXTwoStageDetector, which initializes CenterPoint's submodules (data_preprocessor, pts_voxel_encoder, pts_middle_encoder, ...).
5.2 mmengine.registry.build()
If the config contains a runner_type field, runner = RUNNERS.build(cfg) is executed instead. First look at how RUNNERS is registered; the key is the build_runner_from_cfg function.
RUNNERS = Registry('runner', build_func=build_runner_from_cfg)
build(cfg) then effectively runs build_runner_from_cfg(); jumping in, roughly speaking it builds the runner class named by the config.
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
                          registry: Registry) -> 'Runner':
    from ..config import Config, ConfigDict
    from ..logging import print_log
    assert isinstance(
        cfg,
        (dict, ConfigDict, Config
         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
    assert isinstance(
        registry, Registry), ('registry should be a mmengine.Registry object',
                              f'but got {type(registry)}')

    args = cfg.copy()
    # Runner should be built under target scope, if `_scope_` is defined
    # in cfg, current default scope should switch to specified scope
    # temporarily.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        obj_type = args.get('runner_type', 'Runner')
        if isinstance(obj_type, str):
            runner_cls = registry.get(obj_type)
            if runner_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.name} registry. '
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                )
        elif inspect.isclass(obj_type):
            runner_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        runner = runner_cls.from_cfg(args)  # type: ignore
        print_log(
            f'An `{runner_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
            'registry, its implementation can be found in'
            f'{runner_cls.__module__}',  # type: ignore
            logger='current',
            level=logging.DEBUG)
        return runner
6. Building the dataset and training the model
This functionality lives in runner.train(). This part is quite convoluted, so the blog post MMEngine 之 Runner 调用流程浅析 is a good primer. train() does a lot; here we focus on what two interfaces do: build_train_loop() and train_loop.run().
def train(self) -> nn.Module:
    self._train_loop = self.build_train_loop(self._train_loop)
    self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
    self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
    if self.param_schedulers is not None:
        self.param_schedulers = self.build_param_scheduler(
            self.param_schedulers)
    if self._val_loop is not None:
        self._val_loop = self.build_val_loop(self._val_loop)
    self.call_hook('before_run')
    self._init_model_weights()
    self.load_or_resume()
    self.optim_wrapper.initialize_count_status(
        self.model, self._train_loop.iter, self._train_loop.max_iters)
    model = self.train_loop.run()
    self.call_hook('after_run')
    return model
Let's first analyze build_train_loop():
def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop:
    loop_cfg = copy.deepcopy(loop)  # abridged: BaseLoop / type checks omitted
    if 'type' in loop_cfg:
        loop = LOOPS.build(
            loop_cfg,
            default_args=dict(
                runner=self, dataloader=self._train_dataloader))
    else:
        by_epoch = loop_cfg.pop('by_epoch')
        if by_epoch:
            loop = EpochBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
        else:
            loop = IterBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
    return loop
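The dispatch logic reduces to three cases; a toy version (class names stubbed as strings, helper name is mine) makes the branching explicit:

```python
def choose_loop(loop_cfg):
    """Mirror build_train_loop's dispatch, returning the loop class name."""
    if 'type' in loop_cfg:
        return loop_cfg['type']       # a custom loop from the LOOPS registry
    # default loops: pick epoch-based or iteration-based training
    return 'EpochBasedTrainLoop' if loop_cfg['by_epoch'] else 'IterBasedTrainLoop'
```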
As the snippet shows, building the training flow mainly involves two loop structures, EpochBasedTrainLoop and IterBasedTrainLoop, corresponding to epoch-based and iteration-based training respectively.
Taking EpochBasedTrainLoop as an example, its main functionality sits in __init__ and run; below is a trimmed version of the core code:
class EpochBasedTrainLoop(BaseLoop):

    def __init__(self, runner, dataloader, max_epochs, val_begin=1,
                 val_interval=1, dynamic_intervals=None):
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)
        self._epoch = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        self._max_iters = self._max_epochs * len(self.dataloader)
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(self.val_interval, dynamic_intervals)

    def run(self) -> torch.nn.Module:
        self.runner.call_hook('before_train')
        while self._epoch < self._max_epochs:
            self.run_epoch()
            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()
        self.runner.call_hook('after_train')
        return self.runner.model
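The skeleton of this loop, stripped of runner and validation machinery, can be mimicked in a few lines (a toy sketch of mine, not mmengine code) that record the hook order the real loop follows:

```python
class ToyEpochLoop:
    """Stripped-down EpochBasedTrainLoop: hooks around epochs and iterations."""

    def __init__(self, dataloader, max_epochs):
        self.dataloader = dataloader
        self.max_epochs = max_epochs
        self.epoch = 0
        self.iter = 0
        self.events = []            # records hook order, for illustration

    def run(self):
        self.events.append('before_train')
        while self.epoch < self.max_epochs:
            self.run_epoch()
        self.events.append('after_train')

    def run_epoch(self):
        self.events.append('before_train_epoch')
        for idx, batch in enumerate(self.dataloader):
            # model.train_step(batch, ...) would run here
            self.events.append(f'train_iter_{self.iter}')
            self.iter += 1
        self.events.append('after_train_epoch')
        self.epoch += 1

loop = ToyEpochLoop(dataloader=[1, 2], max_epochs=2)
loop.run()
```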
As seen above, EpochBasedTrainLoop inherits from the base class BaseLoop; let's follow that up.
class BaseLoop(metaclass=ABCMeta):

    def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
        self._runner = runner
        if isinstance(dataloader, dict):
            # Determine whether or not different ranks use different seed.
            diff_rank_seed = runner._randomness_cfg.get(
                'diff_rank_seed', False)
            self.dataloader = runner.build_dataloader(
                dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
        else:
            self.dataloader = dataloader

    @property
    def runner(self):
        return self._runner

    @abstractmethod
    def run(self) -> Any:
        """Execute loop."""
This is where the train_dataloader is actually instantiated, and where the abstract method run() is defined.
Back in EpochBasedTrainLoop's run() method, we finally reach the real training flow; to follow along, it helps to read the code side by side with the official flow diagrams.
Here is the training-related run_epoch() method called from run():
def run_epoch(self) -> None:
    self.runner.call_hook('before_train_epoch')
    self.runner.model.train()
    for idx, data_batch in enumerate(self.dataloader):
        self.run_iter(idx, data_batch)
    self.runner.call_hook('after_train_epoch')
    self._epoch += 1

def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
    self.runner.call_hook(
        'before_train_iter', batch_idx=idx, data_batch=data_batch)
    outputs = self.runner.model.train_step(
        data_batch, optim_wrapper=self.runner.optim_wrapper)
    self.runner.call_hook(
        'after_train_iter', batch_idx=idx, data_batch=data_batch,
        outputs=outputs)
    self._iter += 1
run_iter makes it clear that the bottom of the stack is model.train_step. Taking CenterPoint as the example model, train_step() is implemented in the parent class BaseModel; CenterPoint's inheritance chain is CenterPoint -> MVXTwoStageDetector -> Base3DDetector -> BaseDetector -> BaseModel -> BaseModule -> nn.Module.
def train_step(self, data: Union[dict, tuple, list],
               optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
    # Enable automatic mixed precision training context.
    with optim_wrapper.optim_context(self):
        data = self.data_preprocessor(data, True)
        losses = self._run_forward(data, mode='loss')  # type: ignore
    parsed_losses, log_vars = self.parse_losses(losses)  # type: ignore
    optim_wrapper.update_params(parsed_losses)
    return log_vars
It then calls _run_forward(); the line results = self(**data, mode=mode) means calling the model's own forward method.
def _run_forward(self, data: Union[dict, tuple, list],
                 mode: str) -> Union[Dict[str, torch.Tensor], list]:
    if isinstance(data, dict):
        results = self(**data, mode=mode)
    elif isinstance(data, (list, tuple)):
        results = self(*data, mode=mode)
    else:
        raise TypeError('Output of `data_preprocessor` should be '
                        f'list, tuple or dict, but got {type(data)}')
    return results
So why does self(**data, mode=mode) invoke the model's forward method?
Every model ultimately inherits from torch.nn.modules.module.Module, and Module defines __call__ as _call_impl, whose job is to call forward().
In PyTorch's source (torch/nn/modules/module.py), __call__ is annotated as Callable[..., Any] = _call_impl: Callable denotes a callable type, Any is compatible with every type, and [..., ] means any number of arguments is accepted. In other words, __call__ points at _call_impl, so calling the model object is really calling _call_impl.
__call__: Callable[..., Any] = _call_impl

def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward
                    if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks
            or self._forward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
        return forward_call(*input, **kwargs)
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if self._backward_hooks or _global_backward_hooks:
        full_backward_hooks, non_full_backward_hooks = \
            self._get_backward_hooks()
    if _global_forward_pre_hooks or self._forward_pre_hooks:
        for hook in (*_global_forward_pre_hooks.values(),
                     *self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result

    bw_hook = None
    if full_backward_hooks:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    result = forward_call(*input, **kwargs)
    if _global_forward_hooks or self._forward_hooks:
        for hook in (*_global_forward_hooks.values(),
                     *self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if non_full_backward_hooks:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values()
                            if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result
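The essential point, minus all the hook bookkeeping, is that __call__ delegates to forward. A torch-free toy (my own sketch) reproduces why model(**data, mode=mode) lands in forward():

```python
class ToyModule:
    """Mimics nn.Module's dispatch: calling the object calls forward()."""

    def __call__(self, *args, **kwargs):
        # PyTorch's nn.Module does this (plus hook handling) in _call_impl.
        return self.forward(*args, **kwargs)

class ToyDetector(ToyModule):
    def forward(self, inputs, mode='tensor'):
        if mode == 'loss':
            return {'loss': sum(inputs)}
        return inputs

det = ToyDetector()
```

Calling det([1, 2, 3], mode='loss') runs ToyDetector.forward through ToyModule.__call__, just as self(**data, mode=mode) does in _run_forward.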
From here forward() executes along the inheritance chain; the concrete forward is the one in BaseDetector, which dispatches on mode.
def forward(self,
            inputs: torch.Tensor,
            data_samples: OptSampleList = None,
            mode: str = 'tensor') -> ForwardResults:
    if mode == 'loss':
        return self.loss(inputs, data_samples)
    elif mode == 'predict':
        return self.predict(inputs, data_samples)
    elif mode == 'tensor':
        return self._forward(inputs, data_samples)
    else:
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')
When mode == 'loss', the loss() function in MVXTwoStageDetector runs; it has two parts: feature extraction, and loss computation in the detection head.
def loss(self, batch_inputs_dict: Dict[List, torch.Tensor],
         batch_data_samples: List[Det3DDataSample],
         **kwargs) -> List[Det3DDataSample]:
    batch_input_metas = [item.metainfo for item in batch_data_samples]
    img_feats, pts_feats = self.extract_feat(batch_inputs_dict,
                                             batch_input_metas)
    losses = dict()
    if pts_feats:
        losses_pts = self.pts_bbox_head.loss(pts_feats, batch_data_samples,
                                             **kwargs)
        losses.update(losses_pts)
    if img_feats:
        losses_img = self.loss_imgs(img_feats, batch_data_samples)
        losses.update(losses_img)
    return losses
First look at extract_feat(): following the config, it does some shape manipulation and then enters the feature-extraction functions extract_img_feat() and extract_pts_feat().
def extract_feat(self, batch_inputs_dict: dict,
                 batch_input_metas: List[dict]) -> tuple:
    voxel_dict = batch_inputs_dict.get('voxels', None)
    imgs = batch_inputs_dict.get('imgs', None)
    points = batch_inputs_dict.get('points', None)
    img_feats = self.extract_img_feat(imgs, batch_input_metas)
    pts_feats = self.extract_pts_feat(
        voxel_dict,
        points=points,
        img_feats=img_feats,
        batch_input_metas=batch_input_metas)
    return (img_feats, pts_feats)
Taking point-cloud feature extraction as the example, step into extract_pts_feat(): the input passes through PillarFeatureNet, PointPillarsScatter, SECOND and SECONDFPN before the feature maps come out.
def extract_pts_feat(
    self,
    voxel_dict: Dict[str, Tensor],
    points: Optional[List[Tensor]] = None,
    img_feats: Optional[Sequence[Tensor]] = None,
    batch_input_metas: Optional[List[dict]] = None
) -> Sequence[Tensor]:
    if not self.with_pts_bbox:
        return None
    voxel_features = self.pts_voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'], img_feats,
                                            batch_input_metas)
    batch_size = voxel_dict['coors'][-1, 0] + 1
    x = self.pts_middle_encoder(voxel_features, voxel_dict['coors'],
                                batch_size)
    x = self.pts_backbone(x)
    if self.with_pts_neck:
        x = self.pts_neck(x)
    return x
Once feature extraction is done, the detection head computes its loss; the head used here is CenterHead.
Appendix:
Config file naming convention
The configs already implemented in MMDetection3D live under ./configs and follow a uniform naming rule; the meaning of each field can be looked up in the official docs: https://mmdetection.readthedocs.io/zh_CN/latest/tutorials/config.html#id4.
# Naming rule
{model}_[model setting]_{backbone}_{neck}_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}
# Meaning of each field
{model}: model type, e.g. faster_rcnn, mask_rcnn.
[model setting]: a specific model setting, e.g. without_semantic for htc, moment for reppoints.
{backbone}: backbone type, e.g. r50 (ResNet-50), x101 (ResNeXt-101).
{neck}: neck type, e.g. fpn, pafpn, nasfpn, c4.
[norm_setting]: bn (Batch Normalization) by default; alternatives include gn (Group Normalization) and syncbn (Synchronized Batch Normalization). gn-head/gn-neck means GN is applied only to the head or neck; gn-all means GN is used throughout the model, i.e. backbone, neck and head.
[misc]: miscellaneous settings/plugins, e.g. dconv, gcb, attention, albu, mstrain.
[gpu x batch_per_gpu]: number of GPUs and samples per GPU; 8x2 is the default.
{schedule}: training schedule, options are 1x, 2x, 20e, etc. 1x and 2x mean 12 and 24 epochs respectively; 20e (20 epochs) is used for cascade models. For 1x/2x the initial learning rate decays by a factor of 10 at epochs 8/16 and 11/22; for 20e it decays at epochs 16 and 19.
{dataset}: dataset, e.g. coco, cityscapes, voc_0712, wider_face.
References (blogs and docs)
1、MMEngine理解
2、MMEngine 之 Runner 调用流程浅析
3、mmdetection之config文件
4、深入了解python函数装饰器在mmdetection中的使用(一)
5、带你玩转 3D 检测和分割(一):MMDetection3D 整体框架介绍
6、 轻松掌握 MMDetection 整体构建流程(一)
7、 轻松掌握 MMDetection 整体构建流程(二)