单目深度估计自监督模型Featdepth解读(下)——openMMLab框架使用

在上一篇博客里分析了Featdepth论文原理和核心源码,也就是模型部分,包括网络结构和损失函数计算:

苹果姐:单目深度估计自监督模型Featdepth解读(上)——论文理解和核心源码分析

本篇博客将介绍Featdepth使用的框架–openMMLab的使用以及作者进行的一些修改和扩展。
在这里插入图片描述
Featdepth的源码结构和monodepth2有很大的不同。后者完全是定制化的代码,很适合pytorch入门,前者是使用了商汤的计算机视觉框架OpenMMLab中的基础库mmcv,完全按照mmcv模板写的,在数据读取部分还借鉴了mmdetection的代码,是OpenMMLab中的目标检测库,可以说如果想看懂Featdepth源码结构,必须先学习一下mmcv框架,了解其核心组件Register/Config/Hook/Runner等功能和用法,最好也看看源码。

mmcv工程地址:GitHub - open-mmlab/mmcv: OpenMMLab Computer Vision Foundation

官方文档:Welcome to MMCV’s documentation!

关于OpenMMLab知乎和B站都有博客和视频,我在此只针对Featdepth用到的简要介绍一下。

模型训练部分的代码很短,如下所示:

from __future__ import division

import argparse
from mmcv import Config
from mmcv.runner import load_checkpoint

from mono.datasets.get_dataset import get_dataset
from mono.apis import (train_mono,
                       init_dist,
                       get_root_logger,
                       set_random_seed)
from mono.model.registry import MONO
import torch


def main():
    args = parse_args()
    print(args.config)
    cfg = Config.fromfile(args.config)
    cfg.work_dir = args.work_dir

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = [int(_) for _ in args.gpus.split(',')]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    print('cfg is ', cfg)
    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model_name = cfg.model['name']
    model = MONO.module_dict[model_name](cfg.model)

    if cfg.resume_from is not None:
        load_checkpoint(model, cfg.resume_from, map_location='cpu')
    elif cfg.finetune is not None:
        print('loading from', cfg.finetune)
        checkpoint = torch.load(cfg.finetune, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    train_dataset = get_dataset(cfg.data, training=True)
    if cfg.validate:
        val_dataset = get_dataset(cfg.data, training=False)
    else:
        val_dataset = None

    train_mono(model,
               train_dataset,
               val_dataset,
               cfg,
      
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值