本文主要用于自用,(这篇主要是train部分的),记录了一些我看代码的过程以及理解。如果能帮到您一些最好。
使用方法:ctrl+f定位到你想了解的部分。
代码
def init_tb_loggers(opt):
# initialize wandb logger before tensorboard logger to allow proper sync
if ((opt['logger'].get('wandb') is not None)
and (opt['logger']['wandb'].get('project')is not None)
and ('debug' not in opt['name'])):
assert opt['logger'].get('use_tb_logger') is True,
('should turn on tensorboard when using wandb')
init_wandb_logger(opt)
tb_logger = None
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
return tb_logger
- 这个函数的主要功能是初始化实验日志记录器,包括 TensorBoard 和 Weights & Biases (WandB) 两种日志工具。它会根据配置 (
opt
) 决定是否启用这些日志记录器,并返回 TensorBoard 的日志器实例(如果启用)。 opt
:一个字典,包含训练的所有配置信息(如日志设置、实验名称等)。tb_logger
:TensorBoard 日志记录器实例(如果启用),否则返回None
。
相应配置文件的内容如下:
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
初始化 WandB
日志记录器(如果配置要求):
- 检查是否启用
WandB
:opt['logger'].get('wandb') is not None
:配置中是否定义了wandb
字段。opt['logger']['wandb'].get('project') is not None
:WandB 配置中是否指定了项目名称 (project
)。'debug' not in opt['name']
:实验名称 (opt['name']
) 是否不包含debug
(避免在调试模式下启用WandB
)。
- 验证
TensorBoard
是否启用:- 如果启用
WandB
,则必须同时启用TensorBoard
(opt['logger'].get('use_tb_logger') is True
),否则抛出异常。
- 如果启用
- 初始化
WandB
日志记录器:- 调用
init_wandb_logger(opt)
初始化WandB
。
- 调用
WandB
是一个实验跟踪工具,用于记录训练指标、超参数等。初始化后,训练数据会自动同步到WandB
的云端或本地服务器。
初始化 TensorBoard
日志记录器(如果配置要求)
- 检查是否启用
TensorBoard
:opt['logger'].get('use_tb_logger')
:配置中是否启用了TensorBoard
。'debug' not in opt['name']
:实验名称是否不包含debug
(避免在调试模式下启用TensorBoard
)。
- 初始化
TensorBoard
日志记录器:- 调用
init_tb_logger
- 调用
TensorBoard
是一个可视化工具,用于记录训练过程中的指标(如损失、准确率等)。日志会保存在指定目录中,后续可以通过TensorBoard
查看。
train.py中opt[‘root_path’],但配置文件中并没有‘root_path’
所以对于该项目,确实初始化了wandb
但没用初始化tensorboard
def create_train_val_dataloader(opt, logger):
# create train and val dataloaders
train_loader, val_loaders = None, []
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
train_set = build_dataset(dataset_opt)
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
train_loader = build_dataloader(
train_set,
dataset_opt,
num_gpu=opt['num_gpu'],
dist=opt['dist'], #dist:是否启用分布式训练。
sampler=train_sampler, #sampler:指定自定义采样器。
seed=opt['manual_seed'])
num_iter_per_epoch = math.ceil(
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
logger.info('Training statistics:'
f'\n\tNumber of train images: {len(train_set)}'
f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
f'\n\tWorld size (gpu number): {opt["world_size"]}'
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
elif phase.split('_')[0] == 'val':
val_set = build_dataset(dataset_opt)
val_loader = build_dataloader(
val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
val_loaders.append(val_loader)
# 验证集不需要分布式采样器 (sampler=None)。
# 支持多个验证集(如 val_set5、val_set14)
else:
raise ValueError(f'Dataset phase {phase} is not recognized.')
return train_loader, train_sampler, val_loaders, total_epochs, total_iters
- 输入参数:
logger
:日志记录器,用于输出信息。
- 返回值:
train_loader
:训练集的数据加载器。train_sampler
:训练集的采样器(用于分布式训练)。val_loaders
:列表(支持多个验证集),每个元素是一个验证集的DataLoader
。
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
分布式训练中数据分片和动态扩展的核心实现,它的作用是创建一个自定义的分布式采样器,确保多GPU训练时数据的高效分配和潜在的数据集增强。
world_size
:参与训练的GPU总数rank
:当前GPU的全局编号(0到world_size-1
)
EnlargedSampler
:
-
自定义采样器,支持(功能):
- 分布式训练(多 GPU 数据分片)。
- 在分布式训练中,直接让所有GPU加载相同数据会导致:
- 计算冗余(多GPU重复计算相同梯度)。
- 批量梯度估计偏差(实际等效批量大小虚假增大)。
- 在分布式训练中,直接让所有GPU加载相同数据会导致:
- 数据集放大(通过重复采样)
- 避免重复
- 每轮epoch开始时,重新打乱数据顺序并重新分片,确保:
- 不同epoch看到不同的数据组合。
- 避免固定分片导致的训练偏差。
- 每轮epoch开始时,重新打乱数据顺序并重新分片,确保:
- 分布式训练(多 GPU 数据分片)。
-
总批量大小:
batch_size_per_gpu * world_size
(如16×4=64)。
对应的配置文件:
datasets:
train:
task: SR
name: DF2K
type: PairedImageDataset
dataroot_gt:
- /data/home/sczc338/run/MambaIR/datasets/DF2K/HR
dataroot_lq:
- /data/home/sczc338/run/MambaIR/datasets/DF2K/LR_bicubic/X2
filename_tmpl: '{}x2'
io_backend:
type: disk
gt_size: 128 #训练时随机裁剪的HR图像块大小为128x128像素。
use_hflip: true #启用随机水平翻转作为数据增强。
use_rot: true #启用随机旋转作为数据增强。
# data loader
use_shuffle: true
num_worker_per_gpu: 32 #每个GPU使用32个工作线程加载数据。
batch_size_per_gpu: 1
dataset_enlarge_ratio: 100 #数据集放大比率(可能是虚拟放大,通过重复数据)。
prefetch_mode: ~
val:
name: Set14
type: PairedImageDataset
dataroot_gt: /data/home/sczc338/run/MambaIR/datasets/SR/Set14/HR
dataroot_lq: /data/home/sczc338/run/MambaIR/datasets/SR/Set14/LR_bicubic/X2
filename_tmpl: '{}x2'
io_backend:
type: disk
则for phase, dataset_opt in opt['datasets'].items():
相当于for key,value in dict.items():
opt['datasets']
是一个字典(opt
也是一个字典)。对于opt['datasets']
这个字典来说,phase
是这个字典的键(train
和val
),dataset_opt
是这个字典的值。
dataset_opt
本身也是一个字典,其键为task
,name
,type
…
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
dataset_enlarge_ratio
:- 数据集放大比例(默认
1
),通过重复采样使小数据集在训练时“变大”。 - 例如:若原始数据集有 1000 张图,
ratio=2
等效于 2000 张图。
- 数据集放大比例(默认
num_iter_per_epoch = math.ceil(
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / num_iter_per_epoch)
- 每个 epoch 的迭代次数 =
数据集长度 × 放大比例 / (每 GPU 批大小 × GPU 总数)
例如:1000 张图 × 2(放大) / (16 × 4 GPU) = 31.25 → 32 次迭代/epoch。
- 总 epoch 数 =
总迭代次数 / 每 epoch 迭代次数
。
elif phase.split('_')[0] == 'val':
若配置文件为:
datasets: train: # 训练集 name: DIV2K val_Set5: # 验证集1 name: Set5 val_Set14: # 验证集2 name: Set14
.split('_')
->以下划线 _
为分隔符拆分字符串,返回列表。
split('_')[0]
->取拆分后的第一个部分(即前缀)
由于本项目的配置文件相应的部分为val
而不是val_set5
等,所以可是直接改为elif phase== 'val':
def load_resume_state(opt):
resume_state_path = None
if opt['auto_resume']:
state_path = osp.join('experiments', opt['name'], 'training_states')
if osp.isdir(state_path):
states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
if len(states) != 0:
states = [float(v.split('.state')[0]) for v in states]
resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
opt['path']['resume_state'] = resume_state_path
else:
if opt['path'].get('resume_state'):
resume_state_path = opt['path']['resume_state']
if resume_state_path is None:
resume_state = None
else:
device_id = torch.cuda.current_device()
resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
check_resume(opt, resume_state['iter'])
return resume_state
这个函数的作用是加载训练中断时的恢复状态(如模型权重、优化器状态、当前迭代次数等),支持自动检测和手动指定恢复路径两种模式。
resume_state
:恢复状态的字典(如未找到则返回 None
)。
check_resume(opt, resume_state['iter'])
- 调用
check_resume
检查恢复的迭代次数是否合法(如不超过配置的总迭代次数)
states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
scandir()
:扫描目录suffix='state'
:只匹配以.state
结尾的文件recursive=False
:不递归扫描子目录full_path=False
:返回文件名而非完整路径- 返回文件名列表,如:
['5000.state', '10000.state', '15000.state']
if len(states) != 0:
states = [float(v.split('.state')[0]) for v in states]
resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
opt['path']['resume_state']= resume_state_path
这段代码实现的功能是保存最新的状态文件路径。
def train_pipeline(root_path):
# parse options, set distributed setting, set ramdom seed
opt, args = parse_options(root_path, is_train=True)
opt['root_path'] = root_path
# 启用CuDNN的自动优化器以加速训练。
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# load resume states if necessary
resume_state = load_resume_state(opt)
# mkdir for experiments and logger
# 如果是首次训练(非恢复训练),创建实验目录和日志目录。
if resume_state is None and opt['rank'] == 0:
make_exp_dirs(opt)
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
# copy the yml file to the experiment root
copy_opt_file(args.opt, opt['path']['experiments_root'])
# WARNING: should not use get_root_logger in the above codes, including the called functions
# Otherwise the logger will not be properly initialized
# 初始化日志记录器,记录环境和配置信息。
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info())
logger.info(dict2str(opt))
# initialize wandb and tb loggers
tb_logger = init_tb_loggers(opt)
# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger)
train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
# create model
model = build_model(opt)
if resume_state: # resume training
model.resume_training(resume_state) # handle optimizers and schedulers
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
start_epoch = resume_state['epoch']
current_iter = resume_state['iter']
else:
start_epoch = 0
current_iter = 0
# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger)
# dataloader prefetcher
# 创建消息日志记录器,用于格式化输出训练信息。
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu':
prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda':
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Use {prefetch_mode} prefetch dataloader')
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
else:
raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
# training
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
data_timer, iter_timer = AvgTimer(), AvgTimer()
start_time = time.time()
for epoch in range(start_epoch, total_epochs + 1):
train_sampler.set_epoch(epoch)
prefetcher.reset()
train_data = prefetcher.next()
while train_data is not None:
data_timer.record()
current_iter += 1
if current_iter > total_iters:
break
# update learning rate
model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
# training
model.feed_data(train_data)
model.optimize_parameters(current_iter)
iter_timer.record()
if current_iter == 1:
# reset start time in msg_logger for more accurate eta_time
# not work in resume mode
msg_logger.reset_start_time()
# log
if current_iter % opt['logger']['print_freq'] == 0:
log_vars = {'epoch': epoch, 'iter': current_iter}
log_vars.update({'lrs': model.get_current_learning_rate()})
log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
log_vars.update(model.get_current_log())
msg_logger(log_vars)
# save models and training states
if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.')
model.save(epoch, current_iter)
# validation
if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
if len(val_loaders) > 1:
logger.warning('Multiple validation datasets are *only* supported by SRModel.')
for val_loader in val_loaders:
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
data_timer.start()
iter_timer.start()
train_data = prefetcher.next()
# end of iter
# end of epoch
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}')
logger.info('Save the latest model.')
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
if opt.get('val') is not None:
for val_loader in val_loaders:
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
if tb_logger:
tb_logger.close()
实现端到端的训练流程,包括配置解析、环境初始化、数据加载、模型训练、验证和日志记录。
输出:训练完成的模型和日志文件
1.初始化阶段:
opt, args = parse_options(root_path, is_train=True)
opt['root_path'] = root_path
torch.backends.cudnn.benchmark = True
root_path
:项目根目录路径(通常通过os.path.abspath(__file__)
动态获取opt
:解析配置文件生成的torch.backends.cudnn.benchmark = True
:启用CuDNN加速- 训练速度提升约10-30%(尤其对CNN模型)
- 输入尺寸变化频繁时(如可变分辨率),建议关闭以避免性能下降
args
:命令行参数(如覆盖配置文件的临时参数)
2.恢复训练处理:
resume_state = load_resume_state(opt)
if resume_state is None and opt['rank'] == 0:
make_exp_dirs(opt)
mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
-
resume_state is None
:没有找到可恢复的状态(全新训练/首次训练) -
opt['rank'] == 0
:仅由主GPU进程(rank=0)执行目录操作,避免多进程竞争 -
make_exp_dirs
:创建实验所需的核心目录结构
典型结构为:
experiments/
└── exp1/ # 实验名称
├── models/ # 保存模型权重(如latest.pth, best.pth)
├── training_states/ # 训练状态文件(如10000.state)
└── log/ # 训练日志文件
mkdir_and_rename
:TensorBoard
目录,如果root_path/tb_logger/name
这个路径不存在就创建,如果存在就重命名为带时间戳的形式,如:root_path/tb_logger/name_20250430
3.日志系统初始化
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
logger = get_root_logger(log_file=log_file)
tb_logger = init_tb_loggers(opt)
- 文件日志(记录训练过程细节)
TensorBoard/WandB
日志(可视化指标)
4.数据加载与模型构建
5.训练恢复处理
6.训练主循环
for epoch in range(start_epoch, total_epochs + 1):
train_sampler.set_epoch(epoch) # 分布式训练关键步骤
prefetcher.reset()
while train_data is not None:
current_iter += 1
model.update_learning_rate(current_iter)
model.feed_data(train_data)
model.optimize_parameters(current_iter)
-
train_sampler.set_epoch(epoch)
:在分布式训练中,DistributedSampler
需要知道当前epoch以重新打乱数据分片。确保不同GPU在不同epoch看到不同的数据组合。
(应该是随机种子固定了,所以每一GPU每一epoch得到的数据分片都是固定的,从而需要知道当前的epoch,避免看到重复的数据切片) -
prefetcher.reset()
:清除预取缓冲区,准备新一轮数据加载预取模式:
-
CPUPrefetcher
:多线程将数据从磁盘预读到CPU内存,适合数据量小或CPU瓶颈。 -
CUDAPrefetcher
:直接将数据预取到GPU显存,大数据/高性能GPU。
7.训练过程监督
8.训练收尾
- 保存最终模型
- 关闭
tensorboard
写入
接下来就是看它所有导入的函数的具体实现:
from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options
-
from basicsr.data import build_dataloader
-
from basicsr.data import build_dataset
-
from basicsr.data.data_sampler import EnlargedSampler
-
from basicsr.data.prefetch_dataloader import CPUPrefetcher
-
from basicsr.data.prefetch_dataloader import CUDAPrefetcher
-
from basicsr.models import build_model
def build_model(opt):
"""Build model from options.
Args:
opt (dict): Configuration. It must contain:
model_type (str): Model type.
"""
opt = deepcopy(opt)
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
logger = get_root_logger()
logger.info(f'Model [{model.__class__.__name__}] is created.')
return model
-
MODEL_REGISTRY
:模型注册表(通常为全局字典)- 示例注册表:
MODEL_REGISTRY = { 'SRGAN': SRGANModel, 'EDSR': EDSRModel, 'MambaIR': MambaIRModel }
- 工作流程:
- 通过
opt['model_type']
获取模型类(如SRGANModel
)(SRGANModel
,EDSRModel
,MambaIRModel
都是类!) - 实例化模型并传入配置
opt
-
注册表实现示例:
from basicsr.utils.registry import Registry MODEL_REGISTRY = Registry('model') @MODEL_REGISTRY.register() class SRGANModel(nn.Module): def __init__(self, opt): super().__init__() self.generator = build_generator(opt) self.discriminator = build_discriminator(opt) # 注册其他模型 @MODEL_REGISTRY.register() class MambaIRModel(nn.Module): ...
-
from basicsr.utils.options import copy_opt_file
-
from basicsr.utils.options import dict2str
-
from basicsr.utils.options import parse_options
def parse_options(root_path, is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, default='../options/train/train_MambaIR_SR_x2.yml', help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--local-rank', type=int, default=0) # for pytorch 2.0
parser.add_argument('--local_rank', type=int, default=0) # for pytorch < 2.0
parser.add_argument('--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
args = parser.parse_args()
# parse yml to dict
with open(args.opt, mode='r') as f:
opt = yaml.load(f, Loader=ordered_yaml()[0])
# distributed settings
if args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
init_dist(args.launcher, **opt['dist_params'])
else:
init_dist(args.launcher)
opt['rank'], opt['world_size'] = get_dist_info()
# 获取当前进程的rank和总进程数
# random seed
seed = opt.get('manual_seed')
if seed is None:
seed = random.randint(1, 10000)
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
# force to update yml options
if args.force_yml is not None:
for entry in args.force_yml:
# now do not support creating new keys
keys, value = entry.split('=')
keys, value = keys.strip(), value.strip()
value = _postprocess_yml_value(value)
eval_str = 'opt'
for key in keys.split(':'):
eval_str += f'["{key}"]'
eval_str += '=value'
# using exec function
exec(eval_str)
opt['auto_resume'] = args.auto_resume
opt['is_train'] = is_train
# debug setting
if args.debug and not opt['name'].startswith('debug'):
opt['name'] = 'debug_' + opt['name']
# 自动根据当前机器设置GPU数量。
if opt['num_gpu'] == 'auto':
opt['num_gpu'] = torch.cuda.device_count()
# datasets
for phase, dataset in opt['datasets'].items():
# for multiple datasets, e.g., val_1, val_2; test_1, test_2
phase = phase.split('_')[0]
dataset['phase'] = phase
if 'scale' in opt:
dataset['scale'] = opt['scale']
# paths
for key, val in opt['path'].items():
if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
opt['path'][key] = osp.expanduser(val)
if is_train:
experiments_root = osp.join(root_path, 'experiments', opt['name'])
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
opt['path']['log'] = experiments_root
opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
# change some options for debug mode
if 'debug' in opt['name']:
if 'val' in opt:
opt['val']['val_freq'] = 8
opt['logger']['print_freq'] = 1
opt['logger']['save_checkpoint_freq'] = 8
else: # test
results_root = osp.join(root_path, 'results', opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root
opt['path']['visualization'] = osp.join(results_root, 'visualization')
return opt, args
args
:命令行参数对象
-
命令行参数解析(argparse)
parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, default='../options/train/train_MambaIR_SR_x2.yml') parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none') parser.add_argument('--auto_resume', action='store_true') parser.add_argument('--debug', action='store_true') parser.add_argument('--local-rank', type=int, default=0) # PyTorch 2.0+ parser.add_argument('--local_rank', type=int, default=0) # PyTorch <2.0 parser.add_argument('--force_yml', nargs='+', default=None) args = parser.parse_args()
-
-opt
:主配置文件路径 -
--launcher
:分布式启动方式none
使用场景:单机单卡/多卡(非分布式)。技术实现:直接运行脚本pytorch
使用场景:单机多卡/多机多卡分布式训练。技术实现:使用torch.distributed.launch
或torchrun
slurm
使用场景:多机多卡集群训练。技术实现:通过SLURM作业调度系统启动
维度 none
pytorch
slurm
底层技术 单进程 torch.distributed
torch.distributed
+ Slurm资源管理 用户手动控制 用户手动控制 Slurm自动分配 扩展性 低(单机级) 中(多机级) 高(集群级) 典型命令示例 python train.py
torchrun --nproc_per_node=4 train.py
sbatch job_script.slurm
-
--auto_resume
:自动恢复标志。- 动作类型:
store_true
(当用户在命令行中指定该参数时,--auto_resume
会被设为True
,否则为False
。) - 示例:
python train.py --auto_resume
- 动作类型:
-
--debug
:调试标志(同理) -
--force_yml
: 动态配置覆盖。nargs
:+
表示接受1个或多个参数。- 通过
exec
动态执行赋值,如:opt["train"]["lr"] = 0.001
- 优先级最高:
--force_yml
命令行参数(eg:--force_yml "train:lr=0.01"
)>命令行直接参数(eg:--launcher slurm
)>YAML配置文件默认值
-
args
-
包含了所有通过命令行传入的参数及其对应的值。
-
未指定的参数使用默认值
-
典型
args
对象:-
Namespace( opt='configs/train.yml', launcher='pytorch', auto_resume=True, debug=False, local_rank=0, force_yml=None )
-
-
示例1:当运行脚本时输入:
python train.py --launcher pytorch
。得到的args
对象为Namespace(launcher=pytorch)
-
示例2:当运行脚本时输入:
python train.py -opt my_config.yml --debug
。得到的args
对象为:
Namespace(opt='my_config.yml', debug=True)
-
属性访问方式:
print(args.opt)
# 输出:my_config.yml
print(args.debug)
# 输出:True
-
为什么需要
args
而不用直接读YAML?:- 动态性:允许运行时修改配置(如临时调大学习率)
- 灵活性:同一份YAML可衍生不同实验(通过不同命令行参数)
- 可读性:
args.lr
比config['train']['lr']
更直观
-
如何查看所有可用参数?:
python train.py --help
# 自动生成帮助信息 -
与
yaml
配置文件的关系:
维度 YAML配置文件 命令行args 存储形式 结构化文本文件(层次化键值对) 扁平化命名空间对象 主要用途 定义默认训练参数(模型结构、数据集路径等) 运行时临时覆盖配置/开关控制 修改频率 低频(不同实验间修改) 高频(单实验多次调试时修改) 适用场景 需要版本控制的固定参数 临时性调试参数 优先级:命令行参数>YAML配置
-
-
debug模式:
-
命令行直接启动:
python train.py --debug
# 添加–debug参数即可 -
组合其他参数:(调试模式+修改批量大小+提高日志频率)
python train.py --debug --force_yml "datasets:train:batch_size_per_gpu=2" "logger:print_freq=1"
-
-
其它训练方式:
- 基础训练:
python train.py -opt options/train.yml
- 分布式训练(2GPU):
torchrun --nproc_per_node=2 train.py \ --launcher pytorch \ --opt options/train.yml \ --auto_resume
-
python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 basicsr/train.py -opt options/train/mambairv2/train_MambaIRv2_SR_x2.yml --launcher pytorch
- 基础训练:
with open(args.opt, mode='r') as f:
opt = yaml.load(f, Loader=ordered_yaml()[0])
-
使用
yaml.load
读取yaml
配置文件,ordered_yaml()
确保字段顺序保留 -
ordered_yaml()[0]
:
自定义的加载器,用于保持 YAML 中字段的原始顺序(普通yaml.load
可能返回无序字典) -
为什么需要保持顺序:打印或日志输出时,有序字典更易于定位参数。
-
典型 YAML 配置文件示例
# options/train.yml name: "MambaIR_SR_x2" # 实验名称 model: type: "MambaIR" # 模型类型 scale: 2 # 超分辨率缩放因子 datasets: train: dataroot_gt: "data/DIV2K/train" # 训练集路径 batch_size: 16
-
解析后
opt
的结构:{ 'name': 'MambaIR_SR_x2', 'model': {'type': 'MambaIR', 'scale': 2}, 'datasets': {'train': {'dataroot_gt': 'data/DIV2K/train', 'batch_size': 16}} }
opt['rank'], opt['world_size'] = get_dist_info()
- 获取当前进程的rank和总进程数
- Rank是一个从
0
开始的整数,用于标识分布式训练中每个进程的唯一身份,决定了数据分配、通信顺序和任务分工。 rank
的作用:- 数据分配:决定当前进程处理哪部分数据(如DDP中,rank=0的进程可能负责验证日志记录)
- 梯度同步
- 模型保存:通常仅由rank=0的进程保存模型,避免重复存储
- 日志控制:仅rank=0进程打印日志,避免多进程重复输出
rank
的取值规则:- 单机多卡(1台机器,N张GPU):
rank ∈ [0, N-1]
- 多机多卡(M台机器,每台N张GPU):全局rank范围:
[0, M×N-1]
。每台机器的local_rank(即单机内的GPU编号)范围:[0, N-1]
- 单机多卡(1台机器,N张GPU):
if args.force_yml is not None:
for entry in args.force_yml:
keys, value = entry.split('=')
eval_str = 'opt' + ''.join(f'["{k}"]' for k in keys.split(':')) + '=value'
exec(eval_str) # 执行动态赋值
- 示例:
--force_yml "train:lr=0.001"
会修改opt['train']['lr']
的值。
parser = argparse.ArgumentParser() # 创建参数解析器
args = parser.parse_args() # 解析命令行参数
- 从命令行(如
python train.py --lr 0.01 --batch_size 32
)中读取用户输入的参数,并将其转换为 Python 对象供程序使用。
if args.debug and not opt['name'].startswith('debug'):
opt['name'] = 'debug_' + opt['name'] # 自动添加debug前缀# 修改实验名称
#如name: "MambaIR_SR_x2"--->name: "debug_MambaIR_SR_x2"
# 调试模式下调整验证频率和日志频率
if 'val' in opt:
opt['val']['val_freq'] = 8
opt['logger']['print_freq'] = 1
- 用于 调试模式(Debug Mode)的自动化配置,当用户通过命令行参数
--debug
启用调试模式时,它会自动修改实验配置以方便开发调试
for phase, dataset in opt['datasets'].items():
phase = phase.split('_')[0] # 处理多数据集分支(如val_1, val_2)
dataset['phase'] = phase
if 'scale' in opt:
dataset['scale'] = opt['scale'] # 统一超分辨率缩放因子
-
这里的phase是什么?
opt['datasets']
是一个字典,.items()
来获取键值对。phase
就是键key
-
作用:
datasets
中可能长这样:datasets: train: dataroot_gt: 'data/DIV2K/train' val_1: dataroot_gt: 'data/DIV2K/val' val_2: dataroot_gt: 'data/Set5' test_1: dataroot_gt: 'data/Set5'
for phase, dataset in opt['datasets'].items()
得到的phase
为train,val_1,val_2,test_1
。dataset得到的是dataroot_gt
那么经过处理后得到的:datasets: train: dataroot_gt: 'data/DIV2K/train' phase:train val_1: dataroot_gt: 'data/DIV2K/val' phase:val val_2: dataroot_gt: 'data/Set5' phase:val test_1: dataroot_gt: 'data/Set5' phase:test
从而实现区分训练/验证/测试的数据,便于后续环节的进行。
opt&args&-opt
维度 | opt (配置字典) | args (命令行对象) | -opt (命令行参数) |
---|---|---|---|
数据类型 | 嵌套字典 | 命名空间对象(flat) | 字符串(文件路径) |
来源 | YAML配置文件 | 命令行输入 | 命令行输入(-opt 参数) |
优先级 | 基础配置 | 可覆盖opt | 仅决定opt 的加载路径 |
典型用途 | 定义模型结构、超参数等 | 临时调整实验行为 | 指定配置文件路径 |
opt
(配置字典)
-
来源:从 YAML配置文件 加载(如
train_MambaIR_SR_x2.yml
),通过yaml.load()
解析为Python字典。 -
内容:包含模型、数据集、训练超参数等所有静态配置。
-
特点:
- 层级化结构(嵌套字典),适合管理复杂配置。
- 通常与代码解耦,便于实验复现。
-
示例:
# options/train.yml name: "exp1" model: type: "MambaIR" scale: 2
解析后:
opt = { 'name': 'exp1', 'model': {'type': 'MambaIR', 'scale': 2} }
args
(命令行参数对象)
-
来源:通过
argparse.ArgumentParser()
解析的命令行参数。 -
内容:用户运行时临时指定的动态覆盖选项。
-
特点:
- 扁平结构(单层属性),适合快速调整关键参数。
- 优先级高于
opt
(可覆盖YAML配置)。
-
示例:
python train.py --launcher pytorch --debug
解析后:
args.launcher = "pytorch" args.debug = True
-opt
(命令行参数中的配置路径)
-
来源:
args
的一个特殊字段,通过parser.add_argument('-opt', ...)
定义。 -
作用:指定YAML配置文件的路径,是连接
args
和opt
的桥梁。 -
示例:
python train.py -opt options/train.yml # 指向YAML文件
代码中通过
args.opt
获取路径后加载配置:with open(args.opt) as f: opt = yaml.load(f) # 生成opt字典
from basicsr.utils import
-
AvgTimer
-
MessageLogger
-
check_resume
-
get_env_info
-
get_root_logger
-
get_time_str
-
init_tb_logger
-
init_wandb_logger
-
make_exp_dirs
-
mkdir_and_rename
-
scandir