单目深度估计自监督模型Featdepth解读(下)——openMMLab框架使用

在上一篇博客里分析了Featdepth论文原理和核心源码,也就是模型部分,包括网络结构和损失函数计算:

苹果姐:单目深度估计自监督模型Featdepth解读(上)——论文理解和核心源码分析

本篇博客将介绍Featdepth使用的框架–openMMLab的使用以及作者进行的一些修改和扩展。
在这里插入图片描述
Featdepth的源码结构和monodepth2有很大的不同。后者完全是定制化的代码,很适合pytorch入门,前者是使用了商汤的计算机视觉框架OpenMMLab中的基础库mmcv,完全按照mmcv模板写的,在数据读取部分还借鉴了mmdetection的代码,是OpenMMLab中的目标检测库,可以说如果想看懂Featdepth源码结构,必须先学习一下mmcv框架,了解其核心组件Register/Config/Hook/Runner等功能和用法,最好也看看源码。

mmcv工程地址:GitHub - open-mmlab/mmcv: OpenMMLab Computer Vision Foundation

官方文档:Welcome to MMCV’s documentation!

关于OpenMMLab知乎和B站都有博客和视频,我在此只针对Featdepth用到的简要介绍一下。

模型训练部分的代码很短,如下所示:

from __future__ import division

import argparse
from mmcv import Config
from mmcv.runner import load_checkpoint

from mono.datasets.get_dataset import get_dataset
from mono.apis import (train_mono,
                       init_dist,
                       get_root_logger,
                       set_random_seed)
from mono.model.registry import MONO
import torch


def main():
    args = parse_args()
    print(args.config)
    cfg = Config.fromfile(args.config)
    cfg.work_dir = args.work_dir

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = [int(_) for _ in args.gpus.split(',')]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    print('cfg is ', cfg)
    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model_name = cfg.model['name']
    model = MONO.module_dict[model_name](cfg.model)

    if cfg.resume_from is not None:
        load_checkpoint(model, cfg.resume_from, map_location='cpu')
    elif cfg.finetune is not None:
        print('loading from', cfg.finetune)
        checkpoint = torch.load(cfg.finetune, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    train_dataset = get_dataset(cfg.data, training=True)
    if cfg.validate:
        val_dataset = get_dataset(cfg.data, training=False)
    else:
        val_dataset = None

    train_mono(model,
               train_dataset,
               val_dataset,
               cfg,
               distributed=distributed,
               validate=cfg.validate,
               logger=logger)

首先Config组件用来把各种格式的配置文件读取成Config对象,以便于读取

然后指定了预训练模型地址和GPU、cuda设置、是否使用分布式多卡、日志设置、随机种子等

model_name = cfg.model['name']
model = MONO.module_dict[model_name](cfg.model)

这两句是从已注册的模型字典中取出你想要训练的模型对象。熟悉设计模式的朋友们会发现这是工厂模式的典型用法:通过Register工具将各个模型注册进了工厂的module_dict里,根据配置项中的字符串取出相应的模型类并实例化。具体用法如下:

在from mono.model.registry import MONO这一句中,先是执行了mono/model/init_.py,导入了源码中的四个模型:

from .mono_baseline.net import Baseline
from .mono_autoencoder.net import autoencoder
from .mono_fm.net import mono_fm
from .mono_fm_joint.net import mono_fm_joint

每个net文件被导入时都自动运行了Registry类中的装饰函数:(装饰函数在被修饰的函数或类定义的时候就自动被调用,有关装饰函数的用法不熟悉的话可参见 Python @函数装饰器及用法(超级详细)

@MONO.register_module
class mono_fm_joint(nn.Module):

这里的Registry类是作者自己写的,是源码的简化版,只保留了注册module_dict的功能,也不支持传参。可能因为源码的参数中有个build_func,默认为build_from_cfg,是用来将类实例化的方法,作者省掉了这个方法,所以重新写了一个不需要传参的Registry。MONO是他首先实例化的一个Registry对象,也就是工厂,用来保存module_dict。

import torch
import torch.nn as nn

class Registry(object):
    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not issubclass(module_class, nn.Module):
            raise TypeError(
                'module must be a child of nn.Module, but got {}'.format(
                    module_class))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):       # 作为装饰函数
        self._register_module(cls)
        return cls

MONO = Registry('mono')

后面是预训练模型载入或者断点恢复

然后就是数据读取操作,这里的dataset和monodepth2类似,都是继承自pytorch自带的torch.utils.data.Dataset,进行了相应的扩展。

然后就是train_mono函数了,这是训练的核心功能。

def train_mono(model,
               dataset_train,
               dataset_val,
               cfg,
               distributed=False,
               validate=False,
               logger=None):
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(model, dataset_train, dataset_val, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset_train, dataset_val, cfg, validate=validate)

可以看出训练分为分布式和非分布式。非分布式核心代码如下:

def _non_dist_train(model, dataset_train, dataset_val, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(dataset_train,
                         cfg.imgs_per_gpu,
                         cfg.workers_per_gpu,
                         cfg.gpus.__len__(),
                         dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=cfg.gpus).cuda()
    # build runner
    optimizer = build_optimizer(model,
                                cfg.optimizer)
    runner = Runner(model, batch_processor,
                    optimizer,
                    cfg.work_dir,
                    cfg.log_level)
    runner.register_training_hooks(cfg.lr_config,
                                   cfg.optimizer_config,
                                   cfg.checkpoint_config,
                                   cfg.log_config)

分布式核心代码如下:

def _dist_train(model, dataset_train, dataset_val, cfg, validate=False):
    # prepare data loaders
    data_loaders = [build_dataloader(dataset_train,
                                     cfg.imgs_per_gpu,
                                     cfg.workers_per_gpu,
                                     dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    print('cfg work dir is ', cfg.work_dir)
    runner = Runner(model,
                    batch_processor,
                    optimizer,
                    cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config,
                                   optimizer_config,
                                   cfg.checkpoint_config,
                                   cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())

这几乎就是用了mmcv的模板,从数据读取build_dataloader,到模型包装MMDataParallel/MMDistributedDataParallel、优化器build_optimizer、训练工作流Runner,以及各种hook注册。

这里详细解释一下各个步骤。

1.data_loader部分

一般来说data_loader可以使用pytorch原生的torch.utils.data.DataLoader,但是其中的sampler参数是可以自定义的。pytorch原生且默认的dataloader有两种:RandomSampler和SequentialSampler,以及以batch为单位的BatchSampler,分布式工具torch.utils.data.distributed中提供的是DistributedSampler,但原生的有几个缺点:

一是全都缺少分组功能,假如你的数据集是分为几类的,一个batch输入的图片必须属于同一类,就需要扩展

二是DistributedSampler缺少shuffle功能,假如想随机输入,也需要扩展

三是DistributedSampler只提供多卡数据补全功能,也就是说保证你的图片总数可以被gpu数整除,确保每个gpu有同样数量的图片,但缺少batchsize补全功能,也就是保证每个gpu的图片数量可以被batchsize整除(单卡训练的RandomSampler和SequentialSampler也没有),这时候只能在data_loader初始化时设置drop_last=True,通过去掉最后一个batch来保证每个batch的大小一样(这在有些模型中比较重要,因为可能网络的输入大小限制为batch_size,出现不能整除的情况最后一个batch会报错),这样会导致浪费部分图片,也需要扩展。

作者在这里用了mmdetection中写的三种sampler扩展类源码,其实并没有用到分组功能,只用到了shuffle和batch_size补全功能,分别是:

单卡训练采用GroupSampler,实现分组功能

多卡训练使用了DistributedGroupSampler和加入shuffle功能的DistributedSampler,实现分组和数据补全功能。这个源码就不贴了,看起来比较费劲,在notebook上构造数据跑了一下才理解。

2.模型并行化部分

这里单卡情况使用了MMDataParallel,多卡情况使用了MMDistributedDataParallel。

MMDataParallel的注释中描述了和pytorch的DataParallel的区别:

1.支持一个定制化类型:DataContainer,可以允许对输入数据更加灵活的控制。DataContainer的解释是可以解决原生DataParallel对数据大小必须一致、类型必须一致的限制。

2.支持两个API:train_step()和val_step(),这是mmcv中的工作流控制工具Runner中需要的方法。但MMDataParallel只支持单卡训练,多卡训练要使用MMDistributedDataParallel。MMDistributedDataParallel对pytorch原生的DistributedDataParallel的扩展和MMDataParallel一致。

3.优化器部分

主要通过build_optimizer函数读取配置文件中的优化器配置,从torch.optim中寻找相应的优化器类并实例化。这部分是featdepth作者自己写的简化版,直接用了pytorch自带的,在mmcv源码中支持自定义优化器,也通过Registry来注册和实例化。

4.工作流Runner部分

Runner是mmcv训练部分的引擎,整个训练的流程都由它来控制。详细解释可以参照OpenMMLab:MMCV 核心组件分析(七): Runner

Runner根据Epoch 和 Iter 模式又分为EpochBasedRunner和IterBasedRunner,默认Runner是EpochBasedRunner。需要以下参数进行初始化:

def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):

实际使用的时候调用Runner.run():

def run(self, 
    data_loaders, # dataloader 列表
    workflow,  # 工作流列表,长度需要和 data_loaders 一致
    max_epochs=None, 
    **kwargs):

其中workflow参数决定了工作流的顺序,例如workflow = [(‘train’, 1),(‘val’, 1)],代表一个train一个val流程。run是Runner的入口,会从workflow中读取工作流名称,用getattr()去调用对应的方法。例如调用train()方法:

def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

其中run_iter()函数用来走train流程中的前向传播部分,主要包括计算输出、计算损失函数。

def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

解释:如果自定义了batch_processor方法,则调用batch_processor中的流程,否则调用model中的train_step()(此时model已被MMDataParallel或者MMDistributedDataParallel包装)。在train_step函数内部又进一步调用了model本身的train_step,如果不定义batch_processor,需要在模型中定义这个函数。例如:

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def train_step(self, data, optimizer):
        images, labels = data
        predicts = self(images)  # -> self.__call__() -> self.forward()
        loss = self.loss_fn(predicts, labels)
        return {'loss': loss}

5.HOOK部分

HOOK部分具体介绍也可以参照博客:OpenMMLab:MMCV 核心组件分析(六): Hook

mmcv框架大量用到了HOOK,即除了主流程之外,其他所有的功能都通过HOOK去调用:

self.call_hook(‘after_run’)
框架支持默认HOOK、定制HOOK和自定义HOOK,默认HOOK可以使用runner.register_training_hooks()直接注册,定制HOOK可以从runner中导入之后用runner.registerhook()进行注册,如featdepth中使用的DistSamplerSeedHook,再有一种自定义HOOK,可以由用户自定义编写HOOK内容,就需要继承runner中的HOOK类,定义代表位置的函数和优先级,如before_train_epoch()、beforerun()等,在上层函数调用call_hook()的时候就会按照优先级去依次调用已注册的HOOK中的相应函数。例如:

class DistEvalHook(Hook):
    def __init__(self, dataset, interval=1, cfg=None):
        assert isinstance(dataset, Dataset)
        self.dataset = dataset
        self.interval = interval
        self.cfg = cfg

    def after_train_epoch(self, runner):
        print('evaluation..............................................')

        if not self.every_n_epochs(runner, self.interval):
            return
        runner.model.eval()
        results = [None for _ in range(len(self.dataset))]
        if runner.rank == 0:
            prog_bar = mmcv.ProgressBar(len(self.dataset))

框架部分就介绍到这里,水平有限,欢迎指正

要了解模型的原理和核心代码,请继续阅读:

苹果姐:单目深度估计自监督模型Featdepth解读(上)——论文理解和核心源码分析

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值