def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority:
# config file > default (base filename)
if cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
if not hasattr(cfg, 'dist_params'):
cfg.dist_params = dict(backend='nccl')
这段代码是一个 PyTorch 项目的主函数部分,它展示了训练过程中一些常见的配置设置。逐一解释一下这段代码的主要作用和要点:
-
parse_args()
函数用于解析命令行参数,比如指定配置文件的路径等。 -
Config.fromfile(args.config)
用于从配置文件中加载模型的各种超参数和设置,如网络结构、优化器、损失函数等。 -
torch.backends.cudnn.benchmark = True
可以加速 cuDNN 的自动调优过程,在输入数据尺寸固定的情况下提高训练效率。 -
cfg.work_dir
用于指定模型的保存路径。如果配置文件中没有指定,则默认使用配置文件名作为工作目录。 -
cfg.dist_params
设置了分布式训练的后端为 NCCL,这是一种常用的分布式训练协议。
总的来说,这段代码主要完成了以下几个任务:
- 读取命令行参数和配置文件
- 设置 cuDNN 自动调优
- 确定模型保存路径
- 配置分布式训练
这些准备工作为后续的模型训练和评估提供了良好的基础。整体来说,这段代码展现了一个较为标准的 PyTorch 训练流程。
init_dist(args.launcher, **cfg.dist_params)
rank, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
auto_resume = cfg.get('auto_resume', True)
if auto_resume and cfg.get('resume_from', None) is None:
resume_pth = osp.join(cfg.work_dir, 'latest.pth')
if osp.exists(resume_pth):
cfg.resume_from = resume_pth
-
init_dist(args.launcher, **cfg.dist_params)
函数用于初始化分布式训练环境。根据配置文件中的dist_params
设置,如后端类型、进程数量等,来启动分布式训练。 -
rank, world_size = get_dist_info()
用于获取当前进程的rank和总进程数。这些信息在分布式训练中很重要,比如确定每个进程应该处理的数据范围等。 -
cfg.gpu_ids = range(world_size)
将GPU ID列表设置为[0, 1, ..., world_size-1]
。这样每个进程都可以使用一张GPU进行训练。 -
接下来是自动恢复训练的逻辑。如果
auto_resume
设置为 True,且没有指定resume_from
路径,则尝试从工作目录下的latest.pth
文件恢复训练。 -
如果
latest.pth
文件存在,则将其路径赋值给cfg.resume_from
。这样在后续的训练过程中,就可以从上次保存的模型参数继续训练。
总的来说,这部分代码主要完成了以下几个任务:
- 初始化分布式训练环境
- 获取当前进程的rank和总进程数
- 设置GPU ID列表
- 实现自动恢复训练的功能
这些设置为后续的模型训练和评估提供了良好的基础。
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority:
# config file > default (base filename)
if cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
if not hasattr(cfg, 'dist_params'):
cfg.dist_params = dict(backend='nccl')
init_dist(args.launcher, **cfg.dist_params)
rank, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
auto_resume = cfg.get('auto_resume', True)
if auto_resume and cfg.get('resume_from', None) is None:
resume_pth = osp.join(cfg.work_dir, 'latest.pth')
if osp.exists(resume_pth):
cfg.resume_from = resume_pth
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.get('log_level', 'INFO'))
这个代码片段展示了一个标准的模型训练流程的初始化步骤。仔细分析一下这些步骤的顺序和目的:
-
解析命令行参数:
args = parse_args()
解析用户在命令行中传入的参数,如配置文件路径等。
-
加载配置文件:
cfg = Config.fromfile(args.config)
从用户指定的配置文件中加载实验配置。
-
设置 CUDNN Benchmark:
if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True
如果配置文件中设置了cudnn_benchmark
,则开启 CUDNN 的 benchmark 模式,这可以加速 CUDA 操作。
-
确定工作目录:
if cfg.get('work_dir', None) is None: cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
如果配置文件中没有指定工作目录,则使用配置文件名作为默认工作目录。
-
初始化分布式训练:
init_dist(args.launcher, **cfg.dist_params)
根据用户指定的启动器和分布式训练参数初始化分布式训练环境。rank, world_size = get_dist_info()
获取当前进程的 rank 和总进程数。cfg.gpu_ids = range(world_size)
根据总进程数设置可用的 GPU ID。
-
尝试恢复训练:
auto_resume = cfg.get('auto_resume', True)
如果配置文件中设置了auto_resume
,则尝试从上次的最新checkpoint恢复训练。if auto_resume and cfg.get('resume_from', None) is None: cfg.resume_from = osp.join(cfg.work_dir, 'latest.pth')
如果没有指定恢复路径,则尝试从工作目录下的latest.pth
文件中恢复。
-
创建工作目录并保存配置:
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
创建工作目录。cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
将配置文件保存到工作目录下。
-
初始化日志记录器:
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
获取当前时间戳作为日志文件名。log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
将日志文件保存到工作目录下。logger = get_root_logger(log_file=log_file, log_level=cfg.get('log_level', 'INFO'))
创建全局的日志记录器,并设置日志级别。
总的来说,这段代码主要完成了以下几个任务:
- 加载并解析用户的配置文件和命令行参数。
- 根据配置信息初始化分布式训练环境。
- 尝试从之前的checkpoint恢复训练。
- 创建工作目录并保存当前的配置和日志信息。
这些步骤为后续的模型训练和评估奠定了良好的基础。
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
# log some basic info
logger.info(f'Config: {cfg.pretty_text}')
# set random seeds
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['config_name'] = osp.basename(args.config)
meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))
model = build_model(cfg.model)
if dv(torch.__version__) >= dv('2.0.0') and args.compile:
model = torch.compile(model)
接下来的代码片段继续完成了模型初始化的一些重要步骤,一一解释:
-
初始化元信息字典:
meta = dict()
创建一个空的字典meta
用于记录一些重要的环境信息和随机种子等。
-
记录环境信息:
env_info_dict = collect_env()
收集当前环境的各种信息,如 Python 版本、PyTorch 版本、CUDA 版本等。env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
将环境信息格式化为字符串。logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
将环境信息记录到日志中。meta['env_info'] = env_info
将环境信息保存到meta
字典中。
-
记录配置信息:
logger.info(f'Config: {cfg.pretty_text}')
将配置信息记录到日志中。
-
设置随机种子:
seed = init_random_seed(args.seed)
根据用户指定的种子或自动生成一个随机种子。logger.info(f'Set random seed to {seed}, deterministic: {args.deterministic}')
将随机种子信息记录到日志中。set_random_seed(seed, deterministic=args.deterministic)
设置 PyTorch 和 NumPy 的随机种子。cfg.seed = seed
将随机种子保存到配置中。meta['seed'] = seed
将随机种子保存到meta
字典中。meta['config_name'] = osp.basename(args.config)
保存配置文件名到meta
字典中。meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))
保存工作目录名到meta
字典中。
-
构建模型:
model = build_model(cfg.model)
根据配置文件中的模型参数构建PyTorch模型。if dv(torch.__version__) >= dv('2.0.0') and args.compile: model = torch.compile(model)
如果PyTorch版本大于等于2.0.0,并且用户指定了compile
参数,则使用PyTorch的自动混合精度编译功能优化模型。
总的来说,这一部分代码主要完成了以下几个任务:
- 收集和记录当前环境的各种信息,为后续的模型训练和评估提供上下文。
- 设置随机种子,确保实验的可重复性。
- 根据配置文件构建PyTorch模型,并在可能的情况下进行自动混合精度编译优化。
这些步骤确保了实验的可重复性和模型的性能优化,为后续的训练和评估奠定了基础。
model = build_model(cfg.model)
if dv(torch.__version__) >= dv('2.0.0') and args.compile:
model = torch.compile(model)
datasets = [build_dataset(cfg.data.train)]
cfg.workflow = cfg.get('workflow', [('train', 1)])
assert len(cfg.workflow) == 1
if cfg.checkpoint_config is not None:
# save pyskl version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
pyskl_version=__version__ + get_git_hash(digits=7),
config=cfg.pretty_text)
test_option = dict(test_last=args.test_last, test_best=args.test_best)
详细解释这部分代码的逻辑:
-
构建模型:
model = build_model(cfg.model)
根据配置文件中的模型参数构建PyTorch模型。这个过程可能涉及到模型架构的定义、参数初始化等。
-
编译模型:
if dv(torch.__version__) >= dv('2.0.0') and args.compile: model = torch.compile(model)
如果PyTorch版本大于等于2.0.0,并且用户指定了compile
参数,则使用PyTorch的自动混合精度编译功能优化模型。这可以提高模型的推理速度和内存利用率。
-
构建训练数据集:
datasets = [build_dataset(cfg.data.train)]
根据配置文件中的数据参数构建训练数据集。这可能涉及到数据加载、预处理、增强等操作。
-
设置工作流程:
cfg.workflow = cfg.get('workflow', [('train', 1)])
从配置文件中获取工作流程,默认为只进行训练一轮。assert len(cfg.workflow) == 1
确保工作流程只有一个步骤。
-
设置检查点配置:
if cfg.checkpoint_config is not None: cfg.checkpoint_config.meta = dict(pyskl_version=__version__ + get_git_hash(digits=7), config=cfg.pretty_text)
如果配置文件中存在检查点配置,则在检查点元数据中保存当前 pyskl 版本和配置文件内容,以便后续恢复和分析。
-
设置测试选项:
test_option = dict(test_last=args.test_last, test_best=args.test_best)
根据用户指定的参数,设置是否在训练过程中对最新的模型和最佳的模型进行测试。
总的来说,这一部分代码主要完成了以下几个任务:
- 构建PyTorch模型并对其进行编译优化。
- 构建训练数据集。
- 设置工作流程和检查点配置。
- 设置测试选项。
这些步骤为后续的模型训练和评估做好了充分的准备。
https://github.com/kennymckormick/pyskl/blob/main/tools/train.py