【mmdetection源码阅读一】train.py

前言

研究方向是目标检测,最近找到该开源库,测试过其性能还是不错的,其文章链接为点击下载,github链接为点击查看,注意新版的代码链接,已经支持到了Pytorch1.5,但是截止现在,最新版的代码编译会有个错误,这个还是个开放问题,还没有解决,所以我这里用的是我前两个月下载的一个旧版本,支持Pytorch1.1+的代码库,我昨天使用Pytorch1.5编译失败也是因为以下原因,所以,我使用了Pytorch1.3+CUDA10.1配置的环境。

在这里插入图片描述

源码解读

这里从train.py开始介绍,一点点的阅读源码,有助于后续的代码改进
train.py里面主要包含了两个函数parse_args()和main()

parse_args()函数,在这个配置函数里面,大部分的配置信息我们都是可以理解的,其实比较困难的是关于分布式训练的一些内容,推荐大家阅读下这里
# 这个方法主要是配置一些实验配置参数
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')#加载配置文件
    parser.add_argument('--work_dir', help='the dir to save logs and models')#运行结果储存的位置
    parser.add_argument(
        '--resume_from', help='the checkpoint file to resume from')#断点续训的复训文件夹
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')#是否开启验证
    parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='number of gpus to use '
        '(only applicable to non-distributed training)') #声明GPU的数目
    parser.add_argument('--seed', type=int, default=None, help='random seed')#设置随机种子
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')# 是否为CUDNN后端设置确定性选项
    # 采用哪种分布式训练模式 torch.distributed.init_process_group()
    # 分布式多进程初始化时使用
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher') 
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')# 根据GPU数目自动更改学习率
    args = parser.parse_args()
    #获取系统环境变量 声明分布式训练的本地序号
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
main()函数,关于参数torch.backends.cudnn.benchmark,辅助阅读进入
def main():
    args = parse_args()#获得命令行参数,实际上就是获取config配置文件
    #读取配置文件
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    #在图片输入尺度固定时开启,可以加速,一般都是关的,只有在固定尺度的网络如SSD512中才开启
    #为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    #更新一些配置参数
    # 创建工作目录存放训练文件,如果不键入,会自动从py配置文件中生成对应的目录,key为work_dir
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    # 断点继续训练的权值文件,为None就没有这一步的设置
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus# gpu数目
    # 线性学习率
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    # 单机训练 默认是none 不使用分布式训练
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir 创建文件夹 使用的os.mkdirs()可同时创建多级目录
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps 
    # 初始化一些时间戳,得到一些根日志
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
    #log_level在配置文件里有这个key,value=“INFO”训练一次batch就可以看到输出这个str
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    #得到一些实验环境配置
    env_info_dict = collect_env()
    env_info = '\n'.join([('{}: {}'.format(k, v))
                          for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info('Distributed training: {}'.format(distributed))
    logger.info('Config:\n{}'.format(cfg.text))

    # set random seeds
    #设置随机种子,便于实验复现
    # 默认为None
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    # 加载模型
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    #加载数据集
    datasets = [build_dataset(cfg.data.train)]
    #如果该列表长度为2,则追加验证数据集
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    #判断模型配置是否为空
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.text,
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    #构建训练器
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        timestamp=timestamp,
        meta=meta)
以上为个人见解,如有问题,欢迎指正!
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值