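In the previous post I analyzed the principles of the Featdepth paper and its core source code, namely the model part, including the network architecture and the loss computation:
苹果姐: Interpreting Featdepth, a Self-Supervised Monocular Depth Estimation Model (Part 1): Understanding the Paper and Analyzing the Core Source Code
This post covers OpenMMLab, the framework Featdepth is built on: how it is used, and the modifications and extensions the author made to it.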
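Featdepth's source code is organized very differently from monodepth2's. monodepth2 is entirely custom code and a good starting point for learning PyTorch. Featdepth, by contrast, is built on mmcv, the foundational library of SenseTime's computer vision framework OpenMMLab, and follows mmcv's templates throughout; its data-loading code also borrows from mmdetection, OpenMMLab's object detection library. To understand how the Featdepth source is structured, you therefore need to learn mmcv first, in particular what its core components (Registry, Config, Hook, Runner) do and how to use them, and ideally read their source code as well.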
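Of these components, Registry is the one the training script leans on most directly: the model is built with `MONO.module_dict[model_name](cfg.model)`, i.e. classes register themselves under a name and the config selects one by that name. The snippet below is a minimal, self-contained sketch of that pattern under stated assumptions: `SimpleRegistry`, `ToyDepthModel` and the `num_layers` field are hypothetical names for illustration, and the class is a simplified stand-in, not mmcv's actual `Registry` implementation.

```python
# Hypothetical, simplified registry for illustration only; not mmcv's actual code.
class SimpleRegistry:
    """Maps a string name to a class so a config can select a model by name."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    @property
    def module_dict(self):
        # Read-only view of the name -> class mapping.
        return self._module_dict

    def register_module(self, cls):
        # Used as a class decorator: store the class under its own __name__.
        if cls.__name__ in self._module_dict:
            raise KeyError(f"{cls.__name__} is already registered in {self.name}")
        self._module_dict[cls.__name__] = cls
        return cls


MONO = SimpleRegistry("mono")


@MONO.register_module
class ToyDepthModel:
    # Hypothetical model; a real one would subclass torch.nn.Module.
    def __init__(self, options):
        self.options = options


# A plain dict standing in for the `model` section of a config file
# (`num_layers` is a hypothetical field).
cfg_model = dict(name="ToyDepthModel", num_layers=50)

# Same lookup pattern as `MONO.module_dict[model_name](cfg.model)` in the training script.
model = MONO.module_dict[cfg_model["name"]](cfg_model)
print(type(model).__name__)  # ToyDepthModel
```

With this pattern, adding a new model only requires decorating its class and changing `name` in the config; the training script itself stays the same.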
mmcv repository: GitHub - open-mmlab/mmcv: OpenMMLab Computer Vision Foundation
Official documentation: Welcome to MMCV's documentation!
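There are blog posts on Zhihu and videos on Bilibili covering OpenMMLab, so here I only briefly introduce the parts that Featdepth uses.
The model training script is short, as shown below: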
from __future__ import division
import argparse
from mmcv import Config
from mmcv.runner import load_checkpoint
from mono.datasets.get_dataset import get_dataset
from mono.apis import (train_mono,
                       init_dist,
                       get_root_logger,
                       set_random_seed)
from mono.model.registry import MONO
import torch


def main():
    args = parse_args()
    print(args.config)
    cfg = Config.fromfile(args.config)
    cfg.work_dir = args.work_dir
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = [int(_) for _ in args.gpus.split(',')]
    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
    print('cfg is ', cfg)
    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))
    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)
    # look up the model class registered under cfg.model['name'] and build it
    model_name = cfg.model['name']
    model = MONO.module_dict[model_name](cfg.model)
    if cfg.resume_from is not None:
        load_checkpoint(model, cfg.resume_from, map_location='cpu')
    elif cfg.finetune is not None:
        # finetune: load weights only; strict=False ignores missing/unexpected keys
        print('loading from', cfg.finetune)
        checkpoint = torch.load(cfg.finetune, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    train_dataset = get_dataset(cfg.data, training=True)
    if cfg.validate:
        val_dataset = get_dataset(cfg.data, training=False)
    else:
        val_dataset = None
    train_mono(model,
               train_dataset,
               val_dataset,
               cfg,