MMDetection3D: Starting from train.py

Let's start with the main function of train.py. It consists of six parts:
●Parse the command-line arguments and the config file
●Set up work_dir
●Set up the optimizer
●Set up learning-rate scaling and checkpoint resuming
●Build the network model, build the datasets, and set validation and checkpoint parameters
●Train the model

1. Parsing the command-line arguments and the config file

The command-line arguments are obtained via args = parse_args(). The arguments involved are:

Argument         Description
config           path to the config file
--work-dir       directory for output files
--amp            enable automatic mixed-precision training
--auto-scale-lr  enable automatic learning-rate scaling
--resume         checkpoint path for resuming training
--ceph           use Ceph as the data storage backend
--cfg-options    override settings in the config file
--launcher       launcher options for distributed training
--local_rank     local process rank
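
For orientation, here is a minimal sketch of how these flags are declared with argparse (simplified from tools/train.py: help strings are shortened, --ceph is omitted, and the real script parses --cfg-options with MMEngine's DictAction; double-check against your checkout):

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train a 3D detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument('--amp', action='store_true',
                        help='enable automatic-mixed-precision training')
    parser.add_argument('--auto-scale-lr', action='store_true',
                        help='enable automatically scaling LR')
    parser.add_argument('--resume', nargs='?', type=str, const='auto',
                        help='checkpoint to resume from; bare --resume means "auto"')
    parser.add_argument('--cfg-options', nargs='+',
                        help='override config settings as key=value pairs')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm', 'mpi'],
                        default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    return parser.parse_args()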

The config file is imported via cfg = Config.fromfile(args.config). The config files for all models currently implemented in MMDetection3D are kept in the configs folder.

~/mmdetection3d/configs$ tree -L 1
.
├── 3dssd
├── _base_
├── benchmark
├── centerpoint
├── cylinder3d
├── dfm
├── dgcnn
├── dynamic_voxelization
├── fcaf3d
├── fcos3d
├── free_anchor
├── groupfree3d
├── h3dnet
├── imvotenet
├── imvoxelnet
├── minkunet
├── monoflex
├── mvxnet
├── nuimages
├── paconv
├── parta2
├── pgd
├── pointnet2
├── pointpillars
├── point_rcnn
├── pv_rcnn
├── regnet
├── sassd
├── second
├── smoke
├── spvcnn
├── ssn
└── votenet
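
A quick sketch of loading and inspecting a config (the path below is the CenterPoint config used as the running example later in this post):

from mmengine.config import Config

cfg = Config.fromfile(
    'configs/centerpoint/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py')
print(cfg.optim_wrapper)    # merged in from _base_/schedules/cyclic-20e.py
print(cfg.get('work_dir'))  # None unless the config or the CLI sets it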

2. Setting up work_dir

If work_dir is set on the command line, then cfg.work_dir = args.work_dir; if work_dir is None, ./work_dirs/<config filename> is used as the default working directory.

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
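
A worked example, using the CenterPoint config from later in this post:

# args.config  = 'configs/centerpoint/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py'
# basename     -> 'centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py'
# splitext[0]  -> 'centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d'
# cfg.work_dir -> './work_dirs/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d'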

3. Setting up the optimizer (mixed precision)

This involves the command-line argument --amp and optim_wrapper.type in the config file.

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'
         

Take optim_wrapper in configs/_base_/schedules/cyclic-20e.py as an example:

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01),
    clip_grad=dict(max_norm=35, norm_type=2))

4. Setting the learning rate and resuming training

This involves the command-line argument --auto-scale-lr and the config parameter auto_scale_lr. Take auto_scale_lr in configs/_base_/schedules/cyclic-20e.py as an example:

auto_scale_lr = dict(enable=False, base_batch_size=32)

If args.auto_scale_lr is passed, enable is set to True.

    # enable automatic scaling of the learning rate
    if args.auto_scale_lr:
        if 'auto_scale_lr' in cfg and \
                'enable' in cfg.auto_scale_lr and \
                'base_batch_size' in cfg.auto_scale_lr:
            cfg.auto_scale_lr.enable = True
        else:
            raise RuntimeError('Can not find "auto_scale_lr" or '
                               '"auto_scale_lr.enable" or '
                               '"auto_scale_lr.base_batch_size" in your'
                               ' configuration file.')
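
Note this block only flips the switch; the actual rescaling happens later inside Runner.scale_lr() during train(). Conceptually it is the linear scaling rule; a sketch, assuming an illustrative 8-GPU x 4-samples-per-GPU setup:

# linear scaling rule (sketch; num_gpus and the per-GPU batch size are
# illustrative values, not read from any particular config)
num_gpus = 8
real_batch_size = num_gpus * 4                                  # 32
scale = real_batch_size / cfg.auto_scale_lr['base_batch_size']  # 32 / 32 = 1.0
cfg.optim_wrapper['optimizer']['lr'] *= scale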

Next, resuming from checkpoints: if a checkpoint path is specified, training resumes from it; if --resume is given without a path, the runner tries to resume automatically from the latest checkpoint.

    # resume training from a checkpoint
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume
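
Typical invocations (the checkpoint path in the second command is illustrative):

# resume automatically from the latest checkpoint in work_dir
python tools/train.py configs/centerpoint/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py --resume

# resume from a specific checkpoint
python tools/train.py configs/centerpoint/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py --resume work_dirs/centerpoint/epoch_10.pth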

5. Building the network model

This part is more involved and requires some MMEngine background (for an introduction, see the blog post "MMEngine理解" in the references). Based on whether 'runner_type' is present in the config, either from_cfg() from mmengine.runner or build() from mmengine.registry is used to create the runner object.

    # build the runner from config
    if 'runner_type' not in cfg:
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

5.1 mmengine.runner.from_cfg()

Before stepping into from_cfg(), we first need to know what actually happens when train.py imports the Runner class:

from mmengine.runner import Runner

The Runner class is decorated with @RUNNERS.register_module(), so at import time the Runner class is passed into the registration logic of RUNNERS.register_module():

@RUNNERS.register_module()
class Runner:

This in turn requires a closer look at RUNNERS: how is it imported, and how is it created?
Import:

from mmengine.registry import RUNNERS

Creation: line 14 of mmengine/registry/root.py, which shows that RUNNERS is a Registry object.

RUNNERS = Registry('runner', build_func=build_runner_from_cfg)

It is created through Registry's constructor. A quick look at __init__() shows that it initializes a number of attributes:

 def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None,
                 parent: Optional['Registry'] = None,
                 scope: Optional[str] = None,
                 locations: List = []):
        from .build_functions import build_from_cfg
        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._children: Dict[str, 'Registry'] = dict()
        self._locations = locations
        self._imported = False

        if scope is not None:
            assert isinstance(scope, str)
            self._scope = scope
        else:
            self._scope = self.infer_scope()

        # See https://mypy.readthedocs.io/en/stable/common_issues.html#
        # variables-vs-type-aliases for the use
        self.parent: Optional['Registry']
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_child(self)
            self.parent = parent
        else:
            self.parent = None

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        self.build_func: Callable
        if build_func is None:
            if self.parent is not None:
                self.build_func = self.parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

Back to RUNNERS.register_module(Runner). Stepping into register_module(), its signature shows that here name defaults to None, force defaults to False, and module defaults to None, so the decorator branch is taken: the inner _register function is returned, and it executes self._register_module().

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = False,
            module: Optional[Type] = None) -> Union[type, Callable]:
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register

Stepping into _register_module(): roughly speaking, it fills _module_dict of RUNNERS, i.e. _module_dict = {'Runner': Runner}.

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = False) -> None:
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                existed_module = self.module_dict[name]
                raise KeyError(f'{name} is already registered in {self.name} '
                               f'at {existed_module.__module__}')
            self._module_dict[name] = module
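
To make the whole register-then-build flow concrete, here is a minimal toy example using MMEngine's Registry directly (my own example, not MMEngine source):

from mmengine.registry import Registry

# a toy registry, analogous to RUNNERS or MODELS
TOYS = Registry('toy')

@TOYS.register_module()  # fills TOYS._module_dict with {'Duck': Duck}
class Duck:
    def __init__(self, name='donald'):
        self.name = name

# build() looks up 'Duck' in _module_dict and calls Duck(name='daffy')
duck = TOYS.build(dict(type='Duck', name='daffy'))
print(duck.name)  # 'daffy'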

from_cfg() is implemented at line 451 of mmengine/runner/runner.py.

    @classmethod
    def from_cfg(cls, cfg: ConfigType) -> 'Runner':
        """Build a runner from config.

        Args:
            cfg (ConfigType): A config used for building runner. Keys of
                ``cfg`` can see :meth:`__init__`.

        Returns:
            Runner: A runner build from ``cfg``.
        """
        cfg = copy.deepcopy(cfg)
        runner = cls(
            model=cfg['model'],
            work_dir=cfg['work_dir'],
            train_dataloader=cfg.get('train_dataloader'),
            val_dataloader=cfg.get('val_dataloader'),
            test_dataloader=cfg.get('test_dataloader'),
            train_cfg=cfg.get('train_cfg'),
            val_cfg=cfg.get('val_cfg'),
            test_cfg=cfg.get('test_cfg'),
            auto_scale_lr=cfg.get('auto_scale_lr'),
            optim_wrapper=cfg.get('optim_wrapper'),
            param_scheduler=cfg.get('param_scheduler'),
            val_evaluator=cfg.get('val_evaluator'),
            test_evaluator=cfg.get('test_evaluator'),
            default_hooks=cfg.get('default_hooks'),
            custom_hooks=cfg.get('custom_hooks'),
            data_preprocessor=cfg.get('data_preprocessor'),
            load_from=cfg.get('load_from'),
            resume=cfg.get('resume', False),
            launcher=cfg.get('launcher', 'none'),
            env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))),
            log_processor=cfg.get('log_processor'),
            log_level=cfg.get('log_level', 'INFO'),
            visualizer=cfg.get('visualizer'),
            default_scope=cfg.get('default_scope', 'mmengine'),
            randomness=cfg.get('randomness', dict(seed=None)),
            experiment_name=cfg.get('experiment_name'),
            cfg=cfg,
        )

        return runner

Because of the @classmethod decorator, the runner = cls(...) part is handled by Runner's __init__. Roughly, it populates the runner object's attributes; the most important part, model construction, is around line 429 of runner.py.

 def __init__(
        self,
        model: Union[nn.Module, Dict],
        work_dir: str,
        train_dataloader: Optional[Union[DataLoader, Dict]] = None,
        val_dataloader: Optional[Union[DataLoader, Dict]] = None,
        test_dataloader: Optional[Union[DataLoader, Dict]] = None,
        train_cfg: Optional[Dict] = None,
        val_cfg: Optional[Dict] = None,
        test_cfg: Optional[Dict] = None,
        auto_scale_lr: Optional[Dict] = None,
        optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
        param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
        val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
        custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
        data_preprocessor: Union[nn.Module, Dict, None] = None,
        load_from: Optional[str] = None,
        resume: bool = False,
        launcher: str = 'none',
        env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
        log_processor: Optional[Dict] = None,
        log_level: str = 'INFO',
        visualizer: Optional[Union[Visualizer, Dict]] = None,
        default_scope: str = 'mmengine',
        randomness: Dict = dict(seed=None),
        experiment_name: Optional[str] = None,
        cfg: Optional[ConfigType] = None,
    ):
        self._work_dir = osp.abspath(work_dir)
        mmengine.mkdir_or_exist(self._work_dir)

        # recursively copy the `cfg` because `self.cfg` will be modified
        # everywhere.
        if cfg is not None:
            if isinstance(cfg, Config):
                self.cfg = copy.deepcopy(cfg)
            elif isinstance(cfg, dict):
                self.cfg = Config(cfg)
        else:
            self.cfg = Config(dict())

        self._train_dataloader = train_dataloader
        self._train_loop = train_cfg

        self.optim_wrapper: Optional[Union[OptimWrapper, dict]]
        self.optim_wrapper = optim_wrapper

        self.auto_scale_lr = auto_scale_lr

        self._check_scheduler_cfg(param_scheduler)
        self.param_schedulers = param_scheduler

        self._val_dataloader = val_dataloader
        self._val_loop = val_cfg
        self._val_evaluator = val_evaluator

        self._test_dataloader = test_dataloader
        self._test_loop = test_cfg
        self._test_evaluator = test_evaluator

        self._launcher = launcher
        if self._launcher == 'none':
            self._distributed = False
        else:
            self._distributed = True

        # self._timestamp will be set in the `setup_env` method. Besides,
        # it also will initialize multi-process and (or) distributed
        # environment.
        self.setup_env(env_cfg)
        # self._deterministic and self._seed will be set in the
        # `set_randomness`` method
        self._randomness_cfg = randomness
        self.set_randomness(**randomness)

        if experiment_name is not None:
            self._experiment_name = f'{experiment_name}_{self._timestamp}'
        elif self.cfg.filename is not None:
            filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
            self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
        else:
            self._experiment_name = self.timestamp
        self._log_dir = osp.join(self.work_dir, self.timestamp)
        mmengine.mkdir_or_exist(self._log_dir)

        self.default_scope = default_scope
        # Build log processor to format message.
        log_processor = dict() if log_processor is None else log_processor
        self.log_processor = self.build_log_processor(log_processor)
        # Since `get_instance` could return any subclass of ManagerMixin. The
        # corresponding attribute needs a type hint.
        self.logger = self.build_logger(log_level=log_level)

        # Collect and log environment information.
        self._log_env(env_cfg)

        # See `MessageHub` and `ManagerMixin` for more details.
        self.message_hub = self.build_message_hub()
        # visualizer used for writing log or visualizing all kinds of data
        self.visualizer = self.build_visualizer(visualizer)
        if self.cfg:
            self.visualizer.add_config(self.cfg)

        self._load_from = load_from
        self._resume = resume
        # flag to mark whether checkpoint has been loaded or resumed
        self._has_loaded = False

        # build a model
        self.model = self.build_model(model)
        # wrap model
        self.model = self.wrap_model(
            self.cfg.get('model_wrapper_cfg'), self.model)

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._hooks: List[Hook] = []
        # register hooks to `self._hooks`
        self.register_hooks(default_hooks, custom_hooks)
        # log hooks information
        self.logger.info(f'Hooks will be executed in the following '
                         f'order:\n{self.get_hooks_info()}')

        # dump `cfg` to `work_dir`
        self.dump_config()

Step into build_model(): this function first checks whether model in the config is an nn.Module or a dict. If it is an nn.Module, it is returned directly; if it is a dict, MODELS.build() has to be called.

    def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
        if isinstance(model, nn.Module):
            return model
        elif isinstance(model, dict):
            model = MODELS.build(model)
            return model  # type: ignore
        else:
            raise TypeError('model should be a nn.Module object or dict, '
                            f'but got {model}')

Step straight into build(): it only calls build_func().

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        return self.build_func(cfg, *args, **kwargs, registry=self)

Here build_func() is the one specified when MODELS was registered, i.e. build_model_from_cfg(), at line 206 of mmengine/registry/build_functions.py. The function mainly delegates to build_from_cfg().

def build_model_from_cfg(
    cfg: Union[dict, ConfigDict, Config],
    registry: Registry,
    default_args: Optional[Union[dict, 'ConfigDict', 'Config']] = None
) -> 'nn.Module':
    from ..model import Sequential
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(_cfg, registry, default_args) for _cfg in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
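
A minimal, self-contained sketch of what this registry-build pattern buys you in practice (a toy model of my own, not an MMDetection3D detector):

import torch.nn as nn
from mmengine.registry import MODELS

@MODELS.register_module()  # hypothetical toy model, for illustration only
class TinyNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        return self.fc(x)

# the dict mirrors how `model = dict(type=..., ...)` appears in configs
model = MODELS.build(dict(type='TinyNet', in_channels=4, num_classes=2))
print(type(model).__name__)  # 'TinyNet'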

What build_from_cfg() does is build the model from the config file, so we need to take a look at the config itself. Taking centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py as an example, this config first executes the config files listed in _base_:

_base_ = [
    '../_base_/datasets/nus-3d.py',
    '../_base_/models/centerpoint_pillar02_second_secfpn_nus.py',
    '../_base_/schedules/cyclic-20e.py', '../_base_/default_runtime.py'
]

......
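
Inheritance works by deep-merging: Config.fromfile() first loads every file in _base_, then overlays the child config's keys on top. A runnable toy demo of this behavior (the two config files are hypothetical and written to disk just for the demo):

from mmengine.config import Config

with open('base_demo.py', 'w') as f:
    f.write("model = dict(type='CenterPoint', voxel_size=0.2)\n")
with open('child_demo.py', 'w') as f:
    f.write("_base_ = ['./base_demo.py']\n"
            "model = dict(voxel_size=0.25)\n")  # overrides only this key

cfg = Config.fromfile('child_demo.py')
print(cfg.model.type)        # 'CenterPoint' -> inherited from the base file
print(cfg.model.voxel_size)  # 0.25          -> overridden by the child file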

With the above in mind, step into build_from_cfg(model). The excerpt below (elisions marked with ...) shows the key point: the type field of the model config dict is used to look up the corresponding class in the registry, which is then instantiated.

def build_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    # Avoid circular import
    from ..logging import print_log

    ...
    args = cfg.copy()
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry. ...')
    ...
    if inspect.isclass(obj_cls) and \
            issubclass(obj_cls, ManagerMixin):  # type: ignore
        obj = obj_cls.get_instance(**args)  # type: ignore
    else:
        obj = obj_cls(**args)  # type: ignore

    return obj

Taking model in centerpoint_pillar02_second_secfpn_nus.py as an example: first CenterPoint is stored in obj_type and the type field is popped out of args (obj_type = args.pop('type')); then the corresponding class is looked up in the registry (obj_cls = registry.get(obj_type)); finally the object is instantiated (obj = obj_cls(**args)). This executes CenterPoint's constructor __init__(), and **args are the constructor's arguments.

model = dict(
    type=CenterPoint,
    data_preprocessor=dict(
        type=Det3DDataPreprocessor,
        voxel=True,
        voxel_layer=dict(
            max_num_points=20,
            voxel_size=voxel_size,
            max_voxels=(30000, 40000))),
    pts_voxel_encoder=dict(
        type=PillarFeatureNet,
        in_channels=5,
        feat_channels=[64],
        with_distance=False,
        voxel_size=(0.2, 0.2, 8),
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
        legacy=False),
    pts_middle_encoder=dict(
        type=PointPillarsScatter, in_channels=64, output_shape=(512, 512)),
    pts_backbone=dict(
        type=SECOND,
        in_channels=64,
        out_channels=[64, 128, 256],
        layer_nums=[3, 5, 5],
        layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        conv_cfg=dict(type=Conv2d, bias=False)),
    pts_neck=dict(
        type=SECONDFPN,
        in_channels=[64, 128, 256],
        out_channels=[128, 128, 128],
        upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True),
    pts_bbox_head=dict(
        type=CenterHead,
        in_channels=sum([128, 128, 128]),
        tasks=[
            dict(num_class=1, class_names=['car']),
            dict(num_class=2, class_names=['truck', 'construction_vehicle']),
            dict(num_class=2, class_names=['bus', 'trailer']),
            dict(num_class=1, class_names=['barrier']),
            dict(num_class=2, class_names=['motorcycle', 'bicycle']),
            dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
        ],
        common_heads=dict(
            reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
        share_conv_channel=64,
        bbox_coder=dict(
            type=CenterPointBBoxCoder,
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_num=500,
            score_threshold=0.1,
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            code_size=9),
        separate_head=dict(type=SeparateHead, init_bias=-2.19, final_kernel=3),
        loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
        loss_bbox=dict(
            type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
        norm_bbox=True),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[512, 512, 1],
            voxel_size=voxel_size,
            out_size_factor=4,
            dense_reg=1,
            gaussian_overlap=0.1,
            max_objs=500,
            min_radius=2,
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])),
    test_cfg=dict(
        pts=dict(
            post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            max_per_img=500,
            max_pool_nms=False,
            min_radius=[4, 12, 10, 1, 0.85, 0.175],
            score_threshold=0.1,
            pc_range=[-51.2, -51.2],
            out_size_factor=4,
            voxel_size=voxel_size[:2],
            nms_type='rotate',
            pre_max_size=1000,
            post_max_size=83,
            nms_thr=0.2)))

Following CenterPoint's constructor __init__(), we find that it calls the constructor of its parent class MVXTwoStageDetector, where CenterPoint's submodules (data_preprocessor, pts_voxel_encoder, pts_middle_encoder, ...) are initialized.

5.2 mmengine.registry.build()

If the config file contains the field runner_type, runner = RUNNERS.build(cfg) is executed instead. First, recall how RUNNERS was registered; the key is the build_runner_from_cfg function.

RUNNERS = Registry('runner', build_func=build_runner_from_cfg)

So build(cfg) actually executes build_runner_from_cfg(). Stepping in: roughly, the runner class is resolved from the config and then built from it.

def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
                          registry: Registry) -> 'Runner':
    from ..config import Config, ConfigDict
    from ..logging import print_log

    assert isinstance(
        cfg,
        (dict, ConfigDict, Config
         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
    assert isinstance(
        registry, Registry), ('registry should be a mmengine.Registry object',
                              f'but got {type(registry)}')

    args = cfg.copy()
    # Runner should be built under target scope, if `_scope_` is defined
    # in cfg, current default scope should switch to specified scope
    # temporarily.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        obj_type = args.get('runner_type', 'Runner')
        if isinstance(obj_type, str):
            runner_cls = registry.get(obj_type)
            if runner_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.name} registry. '
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                )
        elif inspect.isclass(obj_type):
            runner_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        runner = runner_cls.from_cfg(args)  # type: ignore
        print_log(
            f'An `{runner_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
            'registry, its implementation can be found in'
            f'{runner_cls.__module__}',  # type: ignore
            logger='current',
            level=logging.DEBUG)
        return runner

6. Building the training set and training the model

Most of the remaining work happens in runner.train(). This part is fairly convoluted; the blog post "MMEngine 之 Runner 调用流程浅析" (see references) is a good primer. train() does a lot; here we focus on what the build_train_loop() and train_loop.run() interfaces do.

def train(self) -> nn.Module:
    self._train_loop = self.build_train_loop(self._train_loop)

    self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
    self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
    if self.param_schedulers is not None:
        self.param_schedulers = self.build_param_scheduler(self.param_schedulers)

    if self._val_loop is not None:
        self._val_loop = self.build_val_loop(self._val_loop)

    self.call_hook('before_run')

    self._init_model_weights()
    self.load_or_resume()
    self.optim_wrapper.initialize_count_status(self.model, self._train_loop.iter, self._train_loop.max_iters)

    model = self.train_loop.run()

    self.call_hook('after_run')
    return model

[Flowchart: the call sequence of train(), drawn from the code above.]

First, let's analyze build_train_loop():

def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop:
    ...
    loop_cfg = copy.deepcopy(loop)

    if 'type' in loop_cfg:
        loop = LOOPS.build(
            loop_cfg,
            default_args=dict(runner=self, dataloader=self._train_dataloader))
    else:
        by_epoch = loop_cfg.pop('by_epoch')
        if by_epoch:
            loop = EpochBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
        else:
            loop = IterBasedTrainLoop(
                **loop_cfg, runner=self, dataloader=self._train_dataloader)
    return loop

As this snippet shows, building the training flow involves two loop structures, EpochBasedTrainLoop and IterBasedTrainLoop, corresponding to epoch-based and iteration-based training respectively; which one is built is controlled by the train_cfg entry of the config, as sketched below.
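
For instance, the loop can be selected explicitly via type or implicitly via by_epoch (the field values here are illustrative, not taken from any particular config):

# epoch-based training: 20 epochs, validate every epoch
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=20, val_interval=1)

# iteration-based training
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=5000)
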
Taking the EpochBasedTrainLoop class as an example, its main functionality lives in __init__ and the run() method. Below is a trimmed version of the core code:

class EpochBasedTrainLoop(BaseLoop):

    def __init__(self, runner, dataloader, max_epochs, val_begin=1,
                 val_interval=1, dynamic_intervals=None):

        super().__init__(runner, dataloader)
        self._max_epochs = max_epochs
        self._max_iters = self._max_epochs * len(self.dataloader)

        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = self.dataloader.dataset.metainfo

        self.dynamic_milestones, self.dynamic_intervals = calc_dynamic_intervals(self.val_interval, dynamic_intervals)

    def run(self) -> torch.nn.Module:

        self.runner.call_hook('before_train')

        while self._epoch < self._max_epochs:
            self.run_epoch()

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model
        

As the code above shows, EpochBasedTrainLoop inherits from the base class BaseLoop; let's step into it.

class BaseLoop(metaclass=ABCMeta):
    def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
        self._runner = runner
        if isinstance(dataloader, dict):
            # Determine whether or not different ranks use different seed.
            diff_rank_seed = runner._randomness_cfg.get('diff_rank_seed', False)
            self.dataloader = runner.build_dataloader(dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
        else:
            self.dataloader = dataloader

    @property
    def runner(self):
        return self._runner

    @abstractmethod
    def run(self) -> Any:
        """Execute loop."""

Here the train_dataloader is actually instantiated, and the abstract method run() is defined.
Back to EpochBasedTrainLoop's run(): we have finally entered the real training flow. To follow along, it helps to read the code side by side with the official flow diagram.
[Figure: the official MMEngine training-loop diagram.]

Here is the training-related run_epoch() method invoked from run():

def run_epoch(self) -> None:

    self.runner.call_hook('before_train_epoch')
    self.runner.model.train()
    for idx, data_batch in enumerate(self.dataloader):
        self.run_iter(idx, data_batch)

    self.runner.call_hook('after_train_epoch')
    self._epoch += 1

def run_iter(self, idx, data_batch: Sequence[dict]) -> None:

    self.runner.call_hook('before_train_iter', batch_idx=idx, data_batch=data_batch)

    outputs = self.runner.model.train_step(data_batch, optim_wrapper=self.runner.optim_wrapper)

    self.runner.call_hook('after_train_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)
    self._iter += 1

run_iter() makes it clear that, at the bottom, model.train_step is called. Taking CenterPoint as the example model, train_step() is implemented in the parent class BaseModel; CenterPoint's inheritance chain is CenterPoint -> MVXTwoStageDetector -> Base3DDetector -> BaseDetector -> BaseModel -> BaseModule -> nn.Module.

 def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        # Enable automatic mixed precision training context.
        with optim_wrapper.optim_context(self):
            data = self.data_preprocessor(data, True)
            losses = self._run_forward(data, mode='loss')  # type: ignore
        parsed_losses, log_vars = self.parse_losses(losses)  # type: ignore
        optim_wrapper.update_params(parsed_losses)
        return log_vars
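
parse_losses() (also defined in BaseModel) turns the loss dict into a single scalar for backprop. A simplified sketch of its logic, assuming every value whose key contains 'loss' contributes to the total (the real implementation additionally handles lists of tensors and distributed reduction):

import torch
from typing import Dict, Tuple

def parse_losses_sketch(losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, dict]:
    # sum every entry whose key contains 'loss'; the rest are only logged
    log_vars = {k: v.mean() for k, v in losses.items()}
    total = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = total
    return total, log_vars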

train_step() then calls _run_forward(), where the line results = self(**data, mode=mode) invokes the model's own forward method.

def _run_forward(self, data: Union[dict, tuple, list],
                     mode: str) -> Union[Dict[str, torch.Tensor], list]:
        if isinstance(data, dict):
            results = self(**data, mode=mode)
        elif isinstance(data, (list, tuple)):
            results = self(*data, mode=mode)
        else:
            raise TypeError('Output of `data_preprocessor` should be '
                            f'list, tuple or dict, but got {type(data)}')
        return results
        

This raises a question: why does self(**data, mode=mode) call the model's forward method?
All models inherit from torch.nn.modules.module.Module, and Module's __call__ is bound to _call_impl(); inside _call_impl(), forward() is invoked.

    # In torch/nn/modules/module.py, __call__ is annotated as
    # `__call__ : Callable[..., Any] = _call_impl`. Callable means the
    # right-hand side must be a callable (here, _call_impl); Any is a special
    # type compatible with every type; [...] accepts any number of arguments.
    # So __call__ points at _call_impl, and calling the module instance
    # actually calls _call_impl.
    __call__ : Callable[..., Any] = _call_impl
    
def _call_impl(self, *input, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
            return forward_call(*input, **kwargs)
        # Do not call functions when jit is used
        full_backward_hooks, non_full_backward_hooks = [], []
        if self._backward_hooks or _global_backward_hooks:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
        if _global_forward_pre_hooks or self._forward_pre_hooks:
            for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
                result = hook(self, input)
                if result is not None:
                    if not isinstance(result, tuple):
                        result = (result,)
                    input = result

        bw_hook = None
        if full_backward_hooks:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)

        result = forward_call(*input, **kwargs)
        if _global_forward_hooks or self._forward_hooks:
            for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    result = hook_result

        if bw_hook:
            result = bw_hook.setup_output_hook(result)

        # Handle the non-full backward hooks
        if non_full_backward_hooks:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

        return result

From here, forward() executes following the inheritance order; the concrete forward is BaseDetector's forward(), which dispatches on mode:

   def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList = None,
                mode: str = 'tensor') -> ForwardResults:
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')
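
In other words, the same model object behaves differently depending on mode (a usage sketch; inputs and data_samples are whatever the data_preprocessor produced):

losses = model(inputs, data_samples, mode='loss')     # dict of loss tensors, used in training
preds  = model(inputs, data_samples, mode='predict')  # post-processed predictions, used in val/test
feats  = model(inputs, mode='tensor')                 # raw network outputs, e.g. for export/debugging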

If mode='loss', the loss() function in MVXTwoStageDetector executes. It has two parts: feature extraction, and the detection head returning the loss.

    def loss(self, batch_inputs_dict: Dict[List, torch.Tensor],
             batch_data_samples: List[Det3DDataSample],
             **kwargs) -> List[Det3DDataSample]:
        batch_input_metas = [item.metainfo for item in batch_data_samples]
        img_feats, pts_feats = self.extract_feat(batch_inputs_dict,
                                                 batch_input_metas)
        losses = dict()
        if pts_feats:
            losses_pts = self.pts_bbox_head.loss(pts_feats, batch_data_samples,
                                                 **kwargs)
            losses.update(losses_pts)
        if img_feats:
            losses_img = self.loss_imgs(img_feats, batch_data_samples)
            losses.update(losses_img)
        return losses
        

First look at extract_feat(): following the config, after some unpacking of the inputs, it enters the feature-extraction functions extract_img_feat() and extract_pts_feat().

  def extract_feat(self, batch_inputs_dict: dict,
                     batch_input_metas: List[dict]) -> tuple:
        voxel_dict = batch_inputs_dict.get('voxels', None)
        imgs = batch_inputs_dict.get('imgs', None)
        points = batch_inputs_dict.get('points', None)
        img_feats = self.extract_img_feat(imgs, batch_input_metas)
        pts_feats = self.extract_pts_feat(
            voxel_dict,
            points=points,
            img_feats=img_feats,
            batch_input_metas=batch_input_metas)
        return (img_feats, pts_feats)

Taking point-cloud feature extraction as an example, step into extract_pts_feat(): the input passes through PillarFeatureNet, PointPillarsScatter, SECOND, and SECONDFPN, and the feature maps are output.

   def extract_pts_feat(
            self,
            voxel_dict: Dict[str, Tensor],
            points: Optional[List[Tensor]] = None,
            img_feats: Optional[Sequence[Tensor]] = None,
            batch_input_metas: Optional[List[dict]] = None
    ) -> Sequence[Tensor]:
        if not self.with_pts_bbox:
            return None
        voxel_features = self.pts_voxel_encoder(voxel_dict['voxels'],
                                                voxel_dict['num_points'],
                                                voxel_dict['coors'], img_feats,
                                                batch_input_metas)
        batch_size = voxel_dict['coors'][-1, 0] + 1
        x = self.pts_middle_encoder(voxel_features, voxel_dict['coors'],
                                    batch_size)
        x = self.pts_backbone(x)
        if self.with_pts_neck:
            x = self.pts_neck(x)
        return x

Once the features are extracted, the detection head computes its loss; the head used here is CenterHead.

Appendix:

Config file naming convention

The config files implemented in MMDetection3D all live under ./configs and follow a unified naming convention. The meaning of each field can be looked up in the official documentation: https://mmdetection.readthedocs.io/zh_CN/latest/tutorials/config.html#id4

# Naming convention
{model}_[model setting]_{backbone}_{neck}_[norm setting]_[misc]_[gpu x batch_per_gpu]_{schedule}_{dataset}

# Meaning of each field (a worked example follows the list)

    {model}: model type, e.g. faster_rcnn, mask_rcnn.

    [model setting]: a specific setting for the model, e.g. without_semantic for htc, moment for reppoints.

    {backbone}: backbone type, e.g. r50 (ResNet-50), x101 (ResNeXt-101).

    {neck}: neck type, e.g. fpn, pafpn, nasfpn, c4.

    [norm_setting]: bn (Batch Normalization) by default; alternatives include gn (Group Normalization) and syncbn (Synchronized Batch Normalization). gn-head/gn-neck means GN is applied only to the head or neck; gn-all means GN is used throughout the model, e.g. backbone, neck, and head.

    [misc]: miscellaneous settings/plugins of the model, e.g. dconv, gcb, attention, albu, mstrain.

    [gpu x batch_per_gpu]: number of GPUs and samples per GPU; 8x2 is the default.

    {schedule}: training schedule, options are 1x, 2x, 20e, etc. 1x and 2x mean 12 and 24 epochs respectively; 20e is used in cascade models and means 20 epochs. For 1x/2x, the initial learning rate decays by a factor of 10 at epochs 8/16 and 11/22; for 20e, it decays by a factor of 10 at epochs 16 and 19.

    {dataset}: dataset, e.g. coco, cityscapes, voc_0712, wider_face.
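
As a worked example, the CenterPoint config used throughout this post decomposes roughly as follows (this is my reading of the name, so treat the field mapping as an educated guess):

# centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py
#   {model}           centerpoint
#   [model setting]   pillar02   (pillar encoder, 0.2 m voxels)
#   {backbone}        second
#   {neck}            secfpn
#   [gpu x batch]     8xb4       (8 GPUs x 4 samples per GPU)
#   {schedule}        cyclic-20e (cyclic LR schedule, 20 epochs)
#   {dataset}         nus-3d     (nuScenes, 3D detection)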

References (blog posts and docs)

1. MMEngine理解
2. MMEngine 之 Runner 调用流程浅析
3. mmdetection之config文件
4. 深入了解python函数装饰器在mmdetection中的使用(一)
5. 带你玩转 3D 检测和分割(一):MMDetection3D 整体框架介绍
6. 轻松掌握 MMDetection 整体构建流程(一)
7. 轻松掌握 MMDetection 整体构建流程(二)
