OpenMMLab 进阶指南,模型训练测试全流程解析

目录

训练与验证流程

第一站,tools/train.py

第二站,train_model

第三站,runner.run

第四站,model.train_step

测试流程

结语


大家在上手 OpenMMLab 系列算法库时,会不会有这样一种困惑——我们在配置文件中指定的 ResNet 之类的模型,到底是在哪里执行了训练和测试的 forward?以至于 debug 时不知应当从何下手。

为了帮助大家更好地了解 OpenMMLab 系列算法库的训练和测试中的调用关系,今天我们将从 MMClassification 入手,以较为简单的分类任务为例,帮助大家由浅入深的了解训练测试流程的主干部分,加深对 OpenMMLab 算法库的整体了解。

本文适用于所有的 MMClassification 0.x 版本

MMClassification 是 OpenMMLab 旗下的图像分类任务算法库,不仅提供分类任务基准测试和工具,还致力于提供统一的主干网络(backbone)供其他 OpenMMLab 算法库直接调用。

https://github.com/open-mmlab/mmclassification​github.com/open-mmlab/mmclassification


关于 OpenMMLab 架构在训练和测试中的抽象,轻松掌握 MMDetection 整体构建流程(二) 一文做了详细的介绍。

训练与验证流程


在训练开始之前,我们需要编写配置文件。MMClassification 在 configs 文件夹中提供了各种模型常用的样例配置文件,可以直接使用或是稍作修改以用于自己的任务。

完成配置文件的编写之后,我们就可以使用入口脚本 tools/train.py 进行训练和验证。该脚本会进行数据集、模型相关的初始化,并调用高阶 API train_model 来搭建执行器(Runner),模型的训练和验证步骤均由执行器进行调度。

更完整的配置文件教程可见: https://mmclassification.readthedocs.io/zh_CN/latest/tutorials/config.html

这里我们仅以 MMClassification 为基准,介绍从训练入口开始,我们是如何让模型训练起来的,避免大家在 OpenMMLab 架构中迷路,那么让我们出发~


第一站,tools/train.py


正如上文所说,这里是训练和验证的入口脚本。它主要执行的工作是解析命令行参数、环境信息,把这些信息动态更新到配置文件中,做一些诸如打印环境信息、创建工作目录之类的外围操作。除此之外,它还完成了模型和训练数据集的构建。

之后调用高阶 API——train_model 继续我们的训练任务:

def main(): 
    # 读取命令行参数 
    args = parse_args() 
 
    # 读取配置文件 
    cfg = Config.fromfile(args.config) 
    # 合并 `--cfg-options` 至配置文件 
    if args.cfg_options is not None: 
        cfg.merge_from_dict(args.cfg_options) 
 
    # 收集并配置运行设备、工作目录、随机种子等信息 
    ... 
 
    # 构建模型并初始化权重 
    model = build_classifier(cfg.model) 
    model.init_weights() 
 
    # 构建数据集 
    datasets = [build_dataset(cfg.data.train)] 
    ... 
 
    # 调用高阶 API train_model 进行模型训练 
    train_model( 
        model, 
        datasets, 
        cfg, 
        distributed=distributed, 
        validate=(not args.no_validate), 
        timestamp=timestamp, 
        device=args.device, 
        meta=meta) 


第二站,train_model


该函数的主要任务是搭建并执行训练执行器,这里我们通过一份流程图来了解它所做的工作:


在函数的最后,我们使用 runner.run 启动了执行器,由执行器来进行具体的训练。需要额外注意的是:模型的验证并没有使用相同的方式,而是作为执行器的一个钩子,利用 Hook 技术实现模型的验证。


第三站,runner.run


从这里开始,程序代码转入了 MMCV,许多同学在查阅源码时就会有些困惑,不知道接下来该去哪里跟踪源码,执行器到底调用了模型的哪个接口呢?我想要 debug 该去哪里加断点呢?其实这里并不复杂,让我们一步一步跟踪执行器。

这里我们以分类任务最常用的 EpochBasedRunner 为例进行说明。

以下提到的 runner 也均指 EpochBasedRunner,相关代码可以在  https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/epoch_based_runner.py 中找到


如下图所示,runner.run 方法会逐 epoch 地去调用 runner.train 方法,而 runner.train 又会逐 iteration 地去调用 runner.run_iter 方法。


很多人在翻阅执行器源码时会被 run 方法较为复杂的逻辑搞乱,其实其中核心的语句为如下几行:

     def run(self, data_loaders, workflow, max_epochs=None, **kwargs): 
        ... 
        while self.epoch < self._max_epochs: 
            for i, flow in enumerate(workflow): 
                mode, epochs = flow 
                if isinstance(mode, str):  # self.train() 
                    if not hasattr(self, mode): 
                        raise ValueError( 
                            f'runner has no method named "{mode}" to run an ' 
                            'epoch') 
                    epoch_runner = getattr(self, mode) 
                else: 
                    raise TypeError( 
                        'mode in workflow must be a str, but got {}'.format( 
                            type(mode))) 
 
                for _ in range(epochs): 
                    if mode == 'train' and self.epoch >= self._max_epochs: 
                        break 
                    epoch_runner(data_loaders[i], **kwargs) 



那么,代码在哪里调用了 runner.train 方法?这还要追溯到我们的配置文件中,在默认的配置文件中都会有这么一行:

workflow = [('train', 1)] 



其中第一个元素是 'train' ,对应着代码中的 mode,代码中使用 getattr(self, mode) 的方式取出了执行器的 train 方法。至于相关的 workflow 设计,有兴趣的同学可以看一下 MMCV 核心组件分析(七): Runner,这里我们就不多做介绍,通常也不推荐大家在没有特殊需求的情况下,在分类任务中修改 workflow。

总之,我们终于接近了终点,要从执行器中跳回 MMClassification 了。在 runner.run_iter 中,执行器调用了模型的 train_step 方法如下:

outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) 



第四站,model.train_step


首先一个问题是,执行器中的 self.model 是哪个类?严谨地说,通常情况下它是 MMDataParallel(MMDP) 或者 MMDistributedDataParallel(MMDDP),因为 train_model 函数对模型进行了封装。但这对于我们理解训练流程并不重要,因为 MMDP 或者 MMDDP 只是一层封装,它们还是会调用所封装模型的 train_step 方法。

那么这个被封装的模型是哪个类呢?其实很简单,在配置文件中,我们的 model 字段通常定义如下,其中 type='ImageClassifier',因此我们主模型是 ImageClassifier 类。

model = dict( 
    type='ImageClassifier', 
    backbone=..., 
    neck=..., 
    head=..., 
    )) 



通常,主模型和算法本身的架构相关。如检测任务中,根据算法的不同,主模型可以是 RetinaNetYOLOX 这样的算法。但在分类任务中,由于 MMClassification 目前还仅支持单标签和多标签的监督学习,这些算法基本都遵循着 “主干网络+可选的 GAP +分类头” 的总体结构,因而我们只有 ImageClassifier 这么一个主模型,期待将来 MMClassficiation 支持更多的任务吧~

在进入 ImageClassfier.train_step(该方法定义在基类 BaseClassifier 中) 之后,我们发现,train_step 依然是一个“中间商”,它调用了模型的 forward 方法,并指定 return_loss=True,进而调用模型的 forward_train 方法。

    def train_step(self, data, optimizer=None, **kwargs): 
        """mmcls/models/classifiers/base.py""" 
        losses = self(**data)   # --> forward 
        loss, log_vars = self._parse_losses(losses) 
 
        outputs = dict( 
            loss=loss, log_vars=log_vars, num_samples=len(data['img'].data)) 
 
        return outputs 
 
    def forward(self, img, return_loss=True, **kwargs): 
        """mmcls/models/classifiers/base.py""" 
        if return_loss: 
            return self.forward_train(img, **kwargs) 
        else: 
            return self.forward_test(img, **kwargs) 
 
    def forward_train(self, img, gt_label, **kwargs): 
        """mmcls/models/classifiers/image.py""" 
        if self.augments is not None: 
            img, gt_label = self.augments(img, gt_label) 
 
        # 调用 backbone 和 neck 的 forward 
        x = self.extract_feat(img) 
 
        losses = dict() 
        # 在 head 中计算 loss 
        loss = self.head.forward_train(x, gt_label) 
 
        losses.update(loss) 
 
        return losses 



是否有些混乱了?其实简单来说,因为我们将损失函数定义在了分类头中,在训练时我们希望分类头返回损失函数,在验证或测试时我们希望分类头返回各类得分,因此通过 forward 方法和 return_loss 参数来做中间的分发,实际在训练中走的是模型的 forward_train 方法,在这里,数据终于历尽千辛万苦,进入了主干网络、分类头等模型结构中。


测试流程


相较于训练流程,模型的测试流程就简单很多了。这里没有再使用执行器,而是直接在高级 API single_gpu_test 或是 multi_gpu_test 中调用模型进行测试。具体流程如下:

1.在入口脚本 tools/test.py 中,我们完成了命令参数的解析、数据集及 data loader 的构建、模型的构建及封装,并调用 single_gpu_test 或是 multi_gpu_test 获取测试结果

2.在 single_gpu_test 或是 multi_gpu_test 中,我们遍历整个 data loader 中的数据,调用模型的 forward 方法,并传入参数 return_loss=False。在上一节中我们已经提到了,模型的 forward 方法会根据 return_loss 参数执行模型的不同分支,当 return_loss=False 时,会调用模型的 forward_test 函数,去获得模型预测结果,而不是损失函数。

3.forward_test 函数的源码如下。虽然目前 MMClassification 还不支持 TTA(Test-Time Augmentation),但为了保持 OpenMMLab 各算法库风格统一,这里对输入参数 imgs 做了许多额外的判断。在目前 MMClassification 的测试流程中,imgs 参数只会是一个 batch 的图像,即一个形状为 (N, C, H, W) 的 Tensor。因此目前我们可以简单地认为 forward_test 进一步调用了模型的 simple_test 方法

    def forward_test(self, imgs, **kwargs): 
        """ 
        Args: 
            imgs (Tensor | List[Tensor]): the outer list indicates test-time 
                augmentations and inner Tensor should have a shape NxCxHxW, 
                which contains all images in the batch. 
        """ 
        if isinstance(imgs, torch.Tensor): 
            imgs = [imgs] 
        for var, name in [(imgs, 'imgs')]: 
            if not isinstance(var, list): 
                raise TypeError(f'{name} must be a list, but got {type(var)}') 
 
        if len(imgs) == 1: 
            return self.simple_test(imgs[0], **kwargs) 
        else: 
            raise NotImplementedError('aug_test has not been implemented') 

4.终于,我们获得模型在整个数据集中的推理结果,返回到了 tools/test.py 中。之后,我们会调用数据集的 evalutate 方法,将数据集的推理结果传递进去,由 evaluate 方法来处理各种评价指标的计算。

结语

本文我们详细梳理了训练和测试过程中,从入口脚本到模型实际计算接口的全流程,以及中间每一步所做的操作。希望这些内容能够帮助大家理清模型的调用栈,遇到问题时能快速定位到是在哪一层级出了问题,调整训练测试行为时知道应该在哪一层级去做修改。

对于 MMClassification 与 OpenMMLab 系列算法库的整体结构,如果大家有更多希望了解的部分,欢迎留言告诉我们!感谢大家的支持~

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值