训练准备工作(一)

def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority:
    # config file > default (base filename)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])

    if not hasattr(cfg, 'dist_params'):
        cfg.dist_params = dict(backend='nccl')

这段代码是一个 PyTorch 项目的主函数部分,它展示了训练过程中一些常见的配置设置。逐一解释一下这段代码的主要作用和要点:

  1. parse_args() 函数用于解析命令行参数,比如指定配置文件的路径等。

  2. Config.fromfile(args.config) 用于从配置文件中加载模型的各种超参数和设置,如网络结构、优化器、损失函数等。

  3. torch.backends.cudnn.benchmark = True 可以加速 cuDNN 的自动调优过程,在输入数据尺寸固定的情况下提高训练效率。

  4. cfg.work_dir 用于指定模型的保存路径。如果配置文件中没有指定,则默认使用配置文件名作为工作目录。

  5. cfg.dist_params 设置了分布式训练的后端为 NCCL,这是一种常用的分布式训练协议。

总的来说,这段代码主要完成了以下几个任务:

  1. 读取命令行参数和配置文件
  2. 设置 cuDNN 自动调优
  3. 确定模型保存路径
  4. 配置分布式训练

这些准备工作为后续的模型训练和评估提供了良好的基础。整体来说,这段代码展现了一个较为标准的 PyTorch 训练流程。

    init_dist(args.launcher, **cfg.dist_params)
    rank, world_size = get_dist_info()
    cfg.gpu_ids = range(world_size)

    auto_resume = cfg.get('auto_resume', True)
    if auto_resume and cfg.get('resume_from', None) is None:
        resume_pth = osp.join(cfg.work_dir, 'latest.pth')
        if osp.exists(resume_pth):
            cfg.resume_from = resume_pth
  1. init_dist(args.launcher, **cfg.dist_params) 函数用于初始化分布式训练环境。根据配置文件中的 dist_params 设置,如后端类型、进程数量等,来启动分布式训练。

  2. rank, world_size = get_dist_info() 用于获取当前进程的rank和总进程数。这些信息在分布式训练中很重要,比如确定每个进程应该处理的数据范围等。

  3. cfg.gpu_ids = range(world_size) 将GPU ID列表设置为 [0, 1, ..., world_size-1]。这样每个进程都可以使用一张GPU进行训练。

  4. 接下来是自动恢复训练的逻辑。如果 auto_resume 设置为 True,且没有指定 resume_from 路径,则尝试从工作目录下的 latest.pth 文件恢复训练。

  5. 如果 latest.pth 文件存在,则将其路径赋值给 cfg.resume_from。这样在后续的训练过程中,就可以从上次保存的模型参数继续训练。

总的来说,这部分代码主要完成了以下几个任务:

  1. 初始化分布式训练环境
  2. 获取当前进程的rank和总进程数
  3. 设置GPU ID列表
  4. 实现自动恢复训练的功能

这些设置为后续的模型训练和评估提供了良好的基础。

def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority:
    # config file > default (base filename)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])

    if not hasattr(cfg, 'dist_params'):
        cfg.dist_params = dict(backend='nccl')

    init_dist(args.launcher, **cfg.dist_params)
    rank, world_size = get_dist_info()
    cfg.gpu_ids = range(world_size)

    auto_resume = cfg.get('auto_resume', True)
    if auto_resume and cfg.get('resume_from', None) is None:
        resume_pth = osp.join(cfg.work_dir, 'latest.pth')
        if osp.exists(resume_pth):
            cfg.resume_from = resume_pth

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.get('log_level', 'INFO'))

这个代码片段展示了一个标准的模型训练流程的初始化步骤。仔细分析一下这些步骤的顺序和目的:

  1. 解析命令行参数:

    • args = parse_args() 解析用户在命令行中传入的参数,如配置文件路径等。
  2. 加载配置文件:

    • cfg = Config.fromfile(args.config) 从用户指定的配置文件中加载实验配置。
  3. 设置 CUDNN Benchmark:

    • if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True 如果配置文件中设置了 cudnn_benchmark,则开启 CUDNN 的 benchmark 模式,这可以加速 CUDA 操作。
  4. 确定工作目录:

    • if cfg.get('work_dir', None) is None: cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) 如果配置文件中没有指定工作目录,则使用配置文件名作为默认工作目录。
  5. 初始化分布式训练:

    • init_dist(args.launcher, **cfg.dist_params) 根据用户指定的启动器和分布式训练参数初始化分布式训练环境。
    • rank, world_size = get_dist_info() 获取当前进程的 rank 和总进程数。
    • cfg.gpu_ids = range(world_size) 根据总进程数设置可用的 GPU ID。
  6. 尝试恢复训练:

    • auto_resume = cfg.get('auto_resume', True) 如果配置文件中设置了 auto_resume,则尝试从上次的最新checkpoint恢复训练。
    • if auto_resume and cfg.get('resume_from', None) is None: cfg.resume_from = osp.join(cfg.work_dir, 'latest.pth') 如果没有指定恢复路径,则尝试从工作目录下的 latest.pth 文件中恢复。
  7. 创建工作目录并保存配置:

    • mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 创建工作目录。
    • cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 将配置文件保存到工作目录下。
  8. 初始化日志记录器:

    • timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 获取当前时间戳作为日志文件名。
    • log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 将日志文件保存到工作目录下。
    • logger = get_root_logger(log_file=log_file, log_level=cfg.get('log_level', 'INFO')) 创建全局的日志记录器,并设置日志级别。

总的来说,这段代码主要完成了以下几个任务:

  1. 加载并解析用户的配置文件和命令行参数。
  2. 根据配置信息初始化分布式训练环境。
  3. 尝试从之前的checkpoint恢复训练。
  4. 创建工作目录并保存当前的配置和日志信息。

这些步骤为后续的模型训练和评估奠定了良好的基础。

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Config: {cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)

    cfg.seed = seed
    meta['seed'] = seed
    meta['config_name'] = osp.basename(args.config)
    meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))

    model = build_model(cfg.model)
    if dv(torch.__version__) >= dv('2.0.0') and args.compile:
        model = torch.compile(model)

接下来的代码片段继续完成了模型初始化的一些重要步骤,一一解释:

  1. 初始化元信息字典:

    • meta = dict() 创建一个空的字典 meta 用于记录一些重要的环境信息和随机种子等。
  2. 记录环境信息:

    • env_info_dict = collect_env() 收集当前环境的各种信息,如 Python 版本、PyTorch 版本、CUDA 版本等。
    • env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 将环境信息格式化为字符串。
    • logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line) 将环境信息记录到日志中。
    • meta['env_info'] = env_info 将环境信息保存到 meta 字典中。
  3. 记录配置信息:

    • logger.info(f'Config: {cfg.pretty_text}') 将配置信息记录到日志中。
  4. 设置随机种子:

    • seed = init_random_seed(args.seed) 根据用户指定的种子或自动生成一个随机种子。
    • logger.info(f'Set random seed to {seed}, deterministic: {args.deterministic}') 将随机种子信息记录到日志中。
    • set_random_seed(seed, deterministic=args.deterministic) 设置 PyTorch 和 NumPy 的随机种子。
    • cfg.seed = seed 将随机种子保存到配置中。
    • meta['seed'] = seed 将随机种子保存到 meta 字典中。
    • meta['config_name'] = osp.basename(args.config) 保存配置文件名到 meta 字典中。
    • meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\')) 保存工作目录名到 meta 字典中。
  5. 构建模型:

    • model = build_model(cfg.model) 根据配置文件中的模型参数构建PyTorch模型。
    • if dv(torch.__version__) >= dv('2.0.0') and args.compile: model = torch.compile(model) 如果PyTorch版本大于等于2.0.0,并且用户指定了 compile参数,则使用PyTorch的自动混合精度编译功能优化模型。

总的来说,这一部分代码主要完成了以下几个任务:

  1. 收集和记录当前环境的各种信息,为后续的模型训练和评估提供上下文。
  2. 设置随机种子,确保实验的可重复性。
  3. 根据配置文件构建PyTorch模型,并在可能的情况下进行自动混合精度编译优化。

这些步骤确保了实验的可重复性和模型的性能优化,为后续的训练和评估奠定了基础。

    model = build_model(cfg.model)
    if dv(torch.__version__) >= dv('2.0.0') and args.compile:
        model = torch.compile(model)

    datasets = [build_dataset(cfg.data.train)]

    cfg.workflow = cfg.get('workflow', [('train', 1)])
    assert len(cfg.workflow) == 1
    if cfg.checkpoint_config is not None:
        # save pyskl version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            pyskl_version=__version__ + get_git_hash(digits=7),
            config=cfg.pretty_text)

    test_option = dict(test_last=args.test_last, test_best=args.test_best)

详细解释这部分代码的逻辑:

  1. 构建模型:

    • model = build_model(cfg.model) 根据配置文件中的模型参数构建PyTorch模型。这个过程可能涉及到模型架构的定义、参数初始化等。
  2. 编译模型:

    • if dv(torch.__version__) >= dv('2.0.0') and args.compile: model = torch.compile(model) 如果PyTorch版本大于等于2.0.0,并且用户指定了 compile 参数,则使用PyTorch的自动混合精度编译功能优化模型。这可以提高模型的推理速度和内存利用率。
  3. 构建训练数据集:

    • datasets = [build_dataset(cfg.data.train)] 根据配置文件中的数据参数构建训练数据集。这可能涉及到数据加载、预处理、增强等操作。
  4. 设置工作流程:

    • cfg.workflow = cfg.get('workflow', [('train', 1)]) 从配置文件中获取工作流程,默认为只进行训练一轮。
    • assert len(cfg.workflow) == 1 确保工作流程只有一个步骤。
  5. 设置检查点配置:

    • if cfg.checkpoint_config is not None: cfg.checkpoint_config.meta = dict(pyskl_version=__version__ + get_git_hash(digits=7), config=cfg.pretty_text) 如果配置文件中存在检查点配置,则在检查点元数据中保存当前 pyskl 版本和配置文件内容,以便后续恢复和分析。
  6. 设置测试选项:

    • test_option = dict(test_last=args.test_last, test_best=args.test_best) 根据用户指定的参数,设置是否在训练过程中对最新的模型和最佳的模型进行测试。

总的来说,这一部分代码主要完成了以下几个任务:

  1. 构建PyTorch模型并对其进行编译优化。
  2. 构建训练数据集。
  3. 设置工作流程和检查点配置。
  4. 设置测试选项。

这些步骤为后续的模型训练和评估做好了充分的准备。

https://github.com/kennymckormick/pyskl/blob/main/tools/train.py

  • 9
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值