OpenMMLab 进阶指南,模型训练测试全流程解析

点击下方卡片,关注“自动驾驶之心”公众号

ADAS巨卷干货,即可获取

点击进入→自动驾驶之心技术交流群

后台回复【ECCV2022】获取ECCV2022所有自动驾驶方向论文!

后台回复【领域综述】获取自动驾驶全栈近80篇综述论文!

后台回复【数据集下载】获取计算机视觉近30种数据集!

892eac607fc36da592402ab08312b3de.gif

大家在上手 OpenMMLab 系列算法库时,会不会有这样一种困惑——我们在配置文件中指定的 ResNet 之类的模型,到底是在哪里执行了训练和测试的 forward?以至于 debug 时不知应当从何下手。

为了帮助大家更好地了解 OpenMMLab 系列算法库的训练和测试中的调用关系,今天我们将从 MMClassification 入手,以较为简单的分类任务为例,帮助大家由浅入深地了解训练测试流程的主干部分,加深对 OpenMMLab 算法库的整体了解。

本文适用于所有的 MMClassification 0.x 版本

MMClassification 是 OpenMMLab 旗下的图像分类任务算法库,不仅提供分类任务基准测试和工具,还致力于提供统一的主干网络(backbone)供其他 OpenMMLab 算法库直接调用。

GitHub 链接:

https://github.com/open-mmlab/mmclassification

关于 OpenMMLab 架构在训练和测试中的抽象,轻松掌握 MMDetection 整体构建流程(二) 一文已做了详细的介绍。

训练与验证流程

在训练开始之前,我们需要编写配置文件。MMClassification 在 configs 文件夹中提供了各种模型常用的样例配置文件,可以直接使用或是稍作修改以用于自己的任务。

完成配置文件的编写之后,我们就可以使用入口脚本 tools/train.py 进行训练和验证。该脚本会进行数据集、模型相关的初始化,并调用高阶 API train_model 来搭建执行器(Runner),模型的训练和验证步骤均由执行器进行调度。

更完整的配置文件教程可见:https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html

这里我们仅以 MMClassification 为基准,介绍从训练入口开始,我们是如何让模型训练起来的,避免大家在 OpenMMLab 架构中迷路,那么让我们出发~

第一站  tools/train.py

正如上文所说,这里是训练和验证的入口脚本。它主要执行的工作是解析命令行参数、环境信息,把这些信息动态更新到配置文件中,做一些诸如打印环境信息、创建工作目录之类的外围操作。除此之外,它还完成了模型和训练数据集的构建。

之后调用高阶 API——train_model 继续我们的训练任务:

def main():
    # 读取命令行参数
    args = parse_args()


    # 读取配置文件
    cfg = Config.fromfile(args.config)
    # 合并 `--cfg-options` 至配置文件
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)


    # 收集并配置运行设备、工作目录、随机种子等信息
    ...


    # 构建模型并初始化权重
    model = build_classifier(cfg.model)
    model.init_weights()


    # 构建数据集
    datasets = [build_dataset(cfg.data.train)]
    ...
    
    # 调用高阶 API train_model 进行模型训练
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        device=args.device,
        meta=meta)

第二站  train_model

该函数的主要任务是搭建并执行训练执行器,这里我们通过一份流程图来了解它所做的工作:

6bdcaea13cb3bee36ff7d93dfd4f7876.png

在函数的最后,我们使用 runner.run 启动了执行器,由执行器来进行具体的训练。需要额外注意的是:模型的验证并没有使用相同的方式,而是作为执行器的一个钩子,利用 Hook 技术实现模型的验证

三站  runner.run

从这里开始,程序代码转入了 MMCV,许多小伙伴在查阅源码时就会有些困惑,不知道接下来该去哪里跟踪源码,执行器到底调用了模型的哪个接口呢?我想要 debug 该去哪里加断点呢?其实这里并不复杂,让我们一步一步跟踪执行器。

这里我们以分类任务最常用的 EpochBasedRunner 为例进行说明。

以下提到的 runner 也均指 EpochBasedRunner

相关代码可以在 https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/epoch_based_runner.py 中找到

如下图所示,runner.run 方法会逐 epoch 地去调用 runner.train 方法,而 runner.train 又会逐 iteration 地去调用 runner.run_iter 方法。

5647a1535c08d88f46b8e38a8ad5b39f.png

很多人在翻阅执行器源码时会被 run 方法较为复杂的逻辑搞乱,其实其中核心的语句为如下几行:

def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        ...
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))


                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

那么,代码在哪里调用了 runner.train 方法?这还要追溯到我们的配置文件中,在默认的配置文件中都会有这么一行:

workflow = [('train', 1)]

其中第一个元素是 'train' ,对应着代码中的 mode,代码中使用 getattr(self, mode) 的方式取出了执行器的 train 方法。至于相关的 workflow 设计,感兴趣的小伙伴可以看一下 MMCV 核心组件分析(七): Runner,这里我们就不多做介绍,通常也不推荐大家在没有特殊需求的情况下,在分类任务中修改 workflow。

MMCV 核心组件分析(七): Runner:

https://zhuanlan.zhihu.com/p/355272459

总之,我们终于接近了终点,要从执行器中跳回 MMClassification 了。在 runner.run_iter 中,执行器调用了模型的 train_step 方法如下:

outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)

第四站  model.train_step

首先一个问题是,执行器中的 self.model 是哪个类?严谨地说,通常情况下它是 MMDataParallel(MMDP) 或者 MMDistributedDataParallel(MMDDP),因为 train_model 函数对模型进行了封装。但这对于我们理解训练流程并不重要,因为 MMDP 或者 MMDDP 只是一层封装,它们还是会调用所封装模型的 train_step 方法。

那么这个被封装的模型是哪个类呢?其实很简单,在配置文件中,我们的 model 字段通常定义如下,其中 type='ImageClassifier',因此我们主模型是 ImageClassifier 类。

model = dict(
    type='ImageClassifier',
    backbone=...,
    neck=...,
    head=...,
    ))

通常,主模型和算法本身的架构相关。如检测任务中,根据算法的不同,主模型可以是 RetinaNetYOLOX 这样的算法。但在分类任务中,由于 MMClassification 目前还仅支持单标签和多标签的监督学习,这些算法基本都遵循着 “主干网络+可选的 GAP +分类头” 的总体结构,因而我们只有 ImageClassifier 这么一个主模型,期待将来 MMClassficiation 支持更多的任务吧~

在进入 ImageClassfier.train_step(该方法定义在基类 BaseClassifier 中) 之后,我们发现,train_step 依然是一个“中间商”,它调用了模型的 forward 方法,并指定 return_loss=True,进而调用模型的 forward_train 方法。

def train_step(self, data, optimizer=None, **kwargs):
        """mmcls/models/classifiers/base.py"""
        losses = self(**data)   # --> forward
        loss, log_vars = self._parse_losses(losses)


        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))


        return outputs
        
    def forward(self, img, return_loss=True, **kwargs):
        """mmcls/models/classifiers/base.py"""
        if return_loss:
            return self.forward_train(img, **kwargs)
        else:
            return self.forward_test(img, **kwargs)
            
    def forward_train(self, img, gt_label, **kwargs):
        """mmcls/models/classifiers/image.py"""
        if self.augments is not None:
            img, gt_label = self.augments(img, gt_label)


        # 调用 backbone 和 neck 的 forward
        x = self.extract_feat(img)


        losses = dict()
        # 在 head 中计算 loss
        loss = self.head.forward_train(x, gt_label)


        losses.update(loss)


        return losses

是否有些混乱了?其实简单来说,因为我们将损失函数定义在了分类头中,在训练时我们希望分类头返回损失函数,在验证或测试时我们希望分类头返回各类得分,因此通过 forward 方法和 return_loss 参数来做中间的分发,实际在训练中走的是模型的 forward_train 方法,在这里,数据终于历尽千辛万苦,进入了主干网络、分类头等模型结构中。

测试流程

相较于训练流程,模型的测试流程就简单很多了。这里没有再使用执行器,而是直接在高级 API single_gpu_test 或是 multi_gpu_test 中调用模型进行测试。具体流程如下:

  1. 在入口脚本 tools/test.py 中,我们完成了命令参数的解析、数据集及 data loader 的构建、模型的构建及封装,并调用 single_gpu_test 或是 multi_gpu_test 获取测试结果

  2. single_gpu_test 或是 multi_gpu_test 中,我们遍历整个 data loader 中的数据,调用模型的 forward 方法,并传入参数 return_loss=False。在上一节中我们已经提到了,模型的 forward 方法会根据 return_loss 参数执行模型的不同分支,当 return_loss=False 时,会调用模型的 forward_test 函数,去获得模型预测结果,而不是损失函数。

  3. forward_test 函数的源码如下。虽然目前 MMClassification 还不支持 TTA(Test-Time Augmentation),但为了保持 OpenMMLab 各算法库风格统一,这里对输入参数 imgs 做了许多额外的判断。在目前 MMClassification 的测试流程中,imgs 参数只会是一个 batch 的图像,即一个形状为 (N, C, H, W) 的 Tensor。因此目前我们可以简单地认为 forward_test 进一步调用了模型的 simple_test 方法

def forward_test(self, imgs, **kwargs):
        """
        Args:
            imgs (Tensor | List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
        """
        if isinstance(imgs, torch.Tensor):
            imgs = [imgs]
        for var, name in [(imgs, 'imgs')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')


        if len(imgs) == 1:
            return self.simple_test(imgs[0], **kwargs)
        else:
            raise NotImplementedError('aug_test has not been implemented')
  1. 终于,我们获得模型在整个数据集中的推理结果,返回到了 tools/test.py 中。之后,我们会调用数据集的 evalutate 方法,将数据集的推理结果传递进去,由 evaluate 方法来处理各种评价指标的计算

结语

本文我们详细梳理了训练和测试过程中,从入口脚本到模型实际计算接口的全流程,以及中间每一步所做的操作。希望这些内容能够帮助大家理清模型的调用栈,遇到问题时能快速定位到是在哪一层级出了问题,调整训练测试行为时知道应该在哪一层级去做修改。

对于 MMClassification 与 OpenMMLab 系列算法库的整体结构,如果大家有更多希望了解的部分,欢迎留言告诉我们!感谢大家的支持~

MMClassification直达:

https://github.com/open-mmlab/mmclassification

自动驾驶之心】全栈技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦目标检测、语义分割、全景分割、实例分割、关键点检测、车道线、目标跟踪、3D感知、多传感器融合、SLAM、高精地图、规划控制、AI模型部署落地等方向;

加入我们:自动驾驶之心技术交流群汇总!

自动驾驶之心【知识星球】

想要了解更多自动驾驶感知(分类、检测、分割、关键点、车道线、3D感知、多传感器融合、目标跟踪)、自动驾驶定位建图(SLAM、高精地图)、自动驾驶规划控制、领域技术方案、AI模型部署落地实战、行业动态、岗位发布,欢迎扫描下方二维码,加入自动驾驶之心知识星球(三天内无条件退款),日常分享论文+代码,这里汇聚行业和学术界大佬,前沿技术方向尽在掌握中,期待交流!

ab2426c375768e4c34ae6301ad13e525.jpeg

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值