MambaIRv2: Attentive State Space Restoration代码解读(train部分)

本文主要用于自用,(这篇主要是train部分的),记录了一些我看代码的过程以及理解。如果能帮到您一些最好。
使用方法:ctrl+f定位到你想了解的部分。
代码

def init_tb_loggers(opt):
    # initialize wandb logger before tensorboard logger to allow proper sync
    if ((opt['logger'].get('wandb') is not None)
            and (opt['logger']['wandb'].get('project')is not None)
            and ('debug' not in opt['name'])):
        assert opt['logger'].get('use_tb_logger') is True, 
        ('should turn on tensorboard when using wandb')
        init_wandb_logger(opt)
    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
    return tb_logger
  • 这个函数的主要功能是初始化实验日志记录器,包括 TensorBoard 和 Weights & Biases (WandB) 两种日志工具。它会根据配置 (opt) 决定是否启用这些日志记录器,并返回 TensorBoard 的日志器实例(如果启用)。
  • opt:一个字典,包含训练的所有配置信息(如日志设置、实验名称等)。
  • tb_logger:TensorBoard 日志记录器实例(如果启用),否则返回 None

相应配置文件的内容如下:

# logging settings
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

初始化 WandB 日志记录器(如果配置要求):

  1. 检查是否启用 WandB
    • opt['logger'].get('wandb') is not None:配置中是否定义了 wandb 字段。
    • opt['logger']['wandb'].get('project') is not None:WandB 配置中是否指定了项目名称 (project)。
    • 'debug' not in opt['name']:实验名称 (opt['name']) 是否不包含 debug(避免在调试模式下启用 WandB)。
  2. 验证 TensorBoard 是否启用
    • 如果启用 WandB,则必须同时启用 TensorBoardopt['logger'].get('use_tb_logger') is True),否则抛出异常。
  3. 初始化 WandB 日志记录器
    • 调用 init_wandb_logger(opt) 初始化 WandB
  • WandB 是一个实验跟踪工具,用于记录训练指标、超参数等。初始化后,训练数据会自动同步到 WandB 的云端或本地服务器。

初始化 TensorBoard 日志记录器(如果配置要求)

  1. 检查是否启用 TensorBoard
    • opt['logger'].get('use_tb_logger'):配置中是否启用了 TensorBoard
    • 'debug' not in opt['name']:实验名称是否不包含 debug(避免在调试模式下启用 TensorBoard)。
  2. 初始化 TensorBoard 日志记录器
    • 调用 init_tb_logger
  • TensorBoard 是一个可视化工具,用于记录训练过程中的指标(如损失、准确率等)。日志会保存在指定目录中,后续可以通过 TensorBoard 查看。

train.py中opt[‘root_path’],但配置文件中并没有‘root_path’

所以对于该项目,确实初始化了wandb但没用初始化tensorboard

def create_train_val_dataloader(opt, logger):
    # create train and val dataloaders
    train_loader, val_loaders = None, []
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = build_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
            
            train_loader = build_dataloader(
                train_set,
                dataset_opt,
                num_gpu=opt['num_gpu'],
                dist=opt['dist'],  #dist:是否启用分布式训练。
                sampler=train_sampler,  #sampler:指定自定义采样器。
                seed=opt['manual_seed'])

            num_iter_per_epoch = math.ceil(
                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
            logger.info('Training statistics:'
                        f'\n\tNumber of train images: {len(train_set)}'
                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
        elif phase.split('_')[0] == 'val':
            val_set = build_dataset(dataset_opt)
            val_loader = build_dataloader(
                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
            val_loaders.append(val_loader)
            # 验证集不需要分布式采样器 (sampler=None)。
            # 支持多个验证集(如 val_set5、val_set14)
        else:
            raise ValueError(f'Dataset phase {phase} is not recognized.')

    return train_loader, train_sampler, val_loaders, total_epochs, total_iters
  • 输入参数
    • logger:日志记录器,用于输出信息。
  • 返回值
    • train_loader:训练集的数据加载器。
    • train_sampler:训练集的采样器(用于分布式训练)。
    • val_loaders列表(支持多个验证集),每个元素是一个验证集的 DataLoader
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)

分布式训练中数据分片和动态扩展的核心实现,它的作用是创建一个自定义的分布式采样器,确保多GPU训练时数据的高效分配和潜在的数据集增强。

  • world_size:参与训练的GPU总数
  • rank:当前GPU的全局编号(0到world_size-1

EnlargedSampler

  • 自定义采样器,支持(功能):

    • 分布式训练(多 GPU 数据分片)。
      • 在分布式训练中,直接让所有GPU加载相同数据会导致:
        • 计算冗余(多GPU重复计算相同梯度)。
        • 批量梯度估计偏差(实际等效批量大小虚假增大)。
    • 数据集放大(通过重复采样)
    • 避免重复
      • 每轮epoch开始时,重新打乱数据顺序并重新分片,确保:
        • 不同epoch看到不同的数据组合。
        • 避免固定分片导致的训练偏差。
  • 总批量大小:batch_size_per_gpu * world_size(如16×4=64)。

对应的配置文件:

datasets:
  train:
    task: SR
    name: DF2K
    type: PairedImageDataset
    dataroot_gt:
    - /data/home/sczc338/run/MambaIR/datasets/DF2K/HR
    dataroot_lq:
    - /data/home/sczc338/run/MambaIR/datasets/DF2K/LR_bicubic/X2
    filename_tmpl: '{}x2'
    io_backend:
      type: disk

    gt_size: 128 #训练时随机裁剪的HR图像块大小为128x128像素。
    use_hflip: true #启用随机水平翻转作为数据增强。
    use_rot: true #启用随机旋转作为数据增强。

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 32 #每个GPU使用32个工作线程加载数据。
    batch_size_per_gpu: 1
    dataset_enlarge_ratio: 100 #数据集放大比率(可能是虚拟放大,通过重复数据)。
    prefetch_mode: ~

  val:
    name: Set14
    type: PairedImageDataset
    dataroot_gt: /data/home/sczc338/run/MambaIR/datasets/SR/Set14/HR
    dataroot_lq: /data/home/sczc338/run/MambaIR/datasets/SR/Set14/LR_bicubic/X2
    filename_tmpl: '{}x2'
    io_backend:
      type: disk

for phase, dataset_opt in opt['datasets'].items():相当于for key,value in dict.items():

opt['datasets']是一个字典(opt也是一个字典)。对于opt['datasets']这个字典来说,phase是这个字典的键(trainval),dataset_opt是这个字典的值。

dataset_opt本身也是一个字典,其键为task,name,type

dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
  • dataset_enlarge_ratio
    • 数据集放大比例(默认 1),通过重复采样使小数据集在训练时“变大”。
    • 例如:若原始数据集有 1000 张图,ratio=2 等效于 2000 张图。
num_iter_per_epoch = math.ceil(
    len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / num_iter_per_epoch)
  • 每个 epoch 的迭代次数 =数据集长度 × 放大比例 / (每 GPU 批大小 × GPU 总数)
    例如:1000 张图 × 2(放大) / (16 × 4 GPU) = 31.25 → 32 次迭代/epoch。
  • 总 epoch 数 = 总迭代次数 / 每 epoch 迭代次数

elif phase.split('_')[0] == 'val':

若配置文件为:

datasets:
train:         # 训练集
 name: DIV2K
val_Set5:      # 验证集1
 name: Set5
val_Set14:     # 验证集2  
 name: Set14

.split('_')->以下划线 _ 为分隔符拆分字符串,返回列表。
split('_')[0]->取拆分后的第一个部分(即前缀)

由于本项目的配置文件相应的部分为val而不是val_set5等,所以可是直接改为elif phase== 'val':

def load_resume_state(opt):
    resume_state_path = None
    if opt['auto_resume']:
        state_path = osp.join('experiments', opt['name'], 'training_states')
        if osp.isdir(state_path):
            states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
            if len(states) != 0:
                states = [float(v.split('.state')[0]) for v in states]
                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
                opt['path']['resume_state'] = resume_state_path
    else:
        if opt['path'].get('resume_state'):
            resume_state_path = opt['path']['resume_state']

    if resume_state_path is None:
        resume_state = None
    else:
        device_id = torch.cuda.current_device()
        resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
        check_resume(opt, resume_state['iter'])
    return resume_state

这个函数的作用是加载训练中断时的恢复状态(如模型权重、优化器状态、当前迭代次数等),支持自动检测和手动指定恢复路径两种模式。

resume_state:恢复状态的字典(如未找到则返回 None)。

check_resume(opt, resume_state['iter'])

  • 调用 check_resume 检查恢复的迭代次数是否合法(如不超过配置的总迭代次数)

states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))

  • scandir():扫描目录
  • suffix='state':只匹配以.state结尾的文件
  • recursive=False:不递归扫描子目录
  • full_path=False:返回文件名而非完整路径
  • 返回文件名列表,如:['5000.state', '10000.state', '15000.state']
if len(states) != 0:
                states = [float(v.split('.state')[0]) for v in states]
                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
                opt['path']['resume_state']= resume_state_path

这段代码实现的功能是保存最新的状态文件路径。

def train_pipeline(root_path):
    # parse options, set distributed setting, set ramdom seed
    opt, args = parse_options(root_path, is_train=True)
    opt['root_path'] = root_path
    # 启用CuDNN的自动优化器以加速训练。
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # load resume states if necessary
    resume_state = load_resume_state(opt)
    # mkdir for experiments and logger
    # 如果是首次训练(非恢复训练),创建实验目录和日志目录。
    if resume_state is None and opt['rank'] == 0:
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
            mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))

    # copy the yml file to the experiment root
    copy_opt_file(args.opt, opt['path']['experiments_root'])

    # WARNING: should not use get_root_logger in the above codes, including the called functions
    # Otherwise the logger will not be properly initialized
    # 初始化日志记录器,记录环境和配置信息。
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))
    # initialize wandb and tb loggers
    tb_logger = init_tb_loggers(opt)
    # create train and validation dataloaders
    result = create_train_val_dataloader(opt, logger)
    train_loader, train_sampler, val_loaders, total_epochs, total_iters = result

    # create model
    model = build_model(opt)

    if resume_state:  # resume training
        model.resume_training(resume_state)  # handle optimizers and schedulers
        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']
    else:
        start_epoch = 0
        current_iter = 0

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    # dataloader prefetcher
    # 创建消息日志记录器,用于格式化输出训练信息。
    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
    if prefetch_mode is None or prefetch_mode == 'cpu':
        prefetcher = CPUPrefetcher(train_loader)
    elif prefetch_mode == 'cuda':
        prefetcher = CUDAPrefetcher(train_loader, opt)
        logger.info(f'Use {prefetch_mode} prefetch dataloader')
        if opt['datasets']['train'].get('pin_memory') is not True:
            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
    else:
        raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")


    # training
    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
    data_timer, iter_timer = AvgTimer(), AvgTimer()
    start_time = time.time()

    for epoch in range(start_epoch, total_epochs + 1):
        train_sampler.set_epoch(epoch)
        prefetcher.reset()
        train_data = prefetcher.next()

        while train_data is not None:
            data_timer.record()

            current_iter += 1
            if current_iter > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_iter)

            iter_timer.record()
            if current_iter == 1:
                # reset start time in msg_logger for more accurate eta_time
                # not work in resume mode
                msg_logger.reset_start_time()
            # log
            if current_iter % opt['logger']['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': model.get_current_learning_rate()})
                log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            # save models and training states
            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(epoch, current_iter)

            # validation
            if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
                if len(val_loaders) > 1:
                    logger.warning('Multiple validation datasets are *only* supported by SRModel.')
                for val_loader in val_loaders:
                    model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])

            data_timer.start()
            iter_timer.start()
            train_data = prefetcher.next()
        # end of iter

    # end of epoch

    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    logger.info(f'End of training. Time consumed: {consumed_time}')
    logger.info('Save the latest model.')
    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
    if opt.get('val') is not None:
        for val_loader in val_loaders:
            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
    if tb_logger:
        tb_logger.close()

实现端到端的训练流程,包括配置解析、环境初始化、数据加载、模型训练、验证和日志记录。

输出:训练完成的模型和日志文件

1.初始化阶段

opt, args = parse_options(root_path, is_train=True)
opt['root_path'] = root_path
torch.backends.cudnn.benchmark = True
  • root_path:项目根目录路径(通常通过os.path.abspath(__file__)动态获取
  • opt:解析配置文件生成的
  • torch.backends.cudnn.benchmark = True:启用CuDNN加速
    • 训练速度提升约10-30%(尤其对CNN模型)
    • 输入尺寸变化频繁时(如可变分辨率),建议关闭以避免性能下降
  • args:命令行参数(如覆盖配置文件的临时参数)

2.恢复训练处理:

resume_state = load_resume_state(opt)
if resume_state is None and opt['rank'] == 0:
    make_exp_dirs(opt)
    mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
  • resume_state is None:没有找到可恢复的状态(全新训练/首次训练)

  • opt['rank'] == 0:仅由主GPU进程(rank=0)执行目录操作,避免多进程竞争

  • make_exp_dirs:创建实验所需的核心目录结构

典型结构为:

experiments/
└── exp1/ # 实验名称
├── models/ # 保存模型权重(如latest.pth, best.pth)
├── training_states/ # 训练状态文件(如10000.state)
└── log/ # 训练日志文件

  • mkdir_and_rename:TensorBoard目录,如果root_path/tb_logger/name这个路径不存在就创建,如果存在就重命名为带时间戳的形式,如:root_path/tb_logger/name_20250430

3.日志系统初始化

log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
logger = get_root_logger(log_file=log_file)
tb_logger = init_tb_loggers(opt)
  • 文件日志(记录训练过程细节)
  • TensorBoard/WandB日志(可视化指标)

4.数据加载与模型构建

5.训练恢复处理

6.训练主循环

for epoch in range(start_epoch, total_epochs + 1):
    train_sampler.set_epoch(epoch)  # 分布式训练关键步骤
    prefetcher.reset()
    while train_data is not None:
        current_iter += 1
        model.update_learning_rate(current_iter)
        model.feed_data(train_data)
        model.optimize_parameters(current_iter)
  • train_sampler.set_epoch(epoch):在分布式训练中,DistributedSampler需要知道当前epoch以重新打乱数据分片。确保不同GPU在不同epoch看到不同的数据组合。
    (应该是随机种子固定了,所以每一GPU每一epoch得到的数据分片都是固定的,从而需要知道当前的epoch,避免看到重复的数据切片)

  • prefetcher.reset():清除预取缓冲区,准备新一轮数据加载

    预取模式:

  • CPUPrefetcher:多线程将数据从磁盘预读到CPU内存,适合数据量小或CPU瓶颈。

  • CUDAPrefetcher:直接将数据预取到GPU显存,大数据/高性能GPU。

7.训练过程监督

8.训练收尾

  • 保存最终模型
  • 关闭tensorboard写入

接下来就是看它所有导入的函数的具体实现:

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options
  1. from basicsr.data import build_dataloader

  2. from basicsr.data import build_dataset

  3. from basicsr.data.data_sampler import EnlargedSampler

  4. from basicsr.data.prefetch_dataloader import CPUPrefetcher

  5. from basicsr.data.prefetch_dataloader import CUDAPrefetcher

  6. from basicsr.models import build_model

def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.
    """
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
  • MODEL_REGISTRY:模型注册表(通常为全局字典)

    • 示例注册表:
    MODEL_REGISTRY = {
     'SRGAN': SRGANModel,
     'EDSR': EDSRModel,
     'MambaIR': MambaIRModel
    }
    
    • 工作流程:
    1. 通过 opt['model_type'] 获取模型类(如 SRGANModel)(SRGANModel,EDSRModel,MambaIRModel都是类!)
    2. 实例化模型并传入配置 opt
    • 注册表实现示例:

      from basicsr.utils.registry import Registry
      
      MODEL_REGISTRY = Registry('model')
      
      @MODEL_REGISTRY.register()
      class SRGANModel(nn.Module):
       def __init__(self, opt):
           super().__init__()
           self.generator = build_generator(opt)
           self.discriminator = build_discriminator(opt)
      
      # 注册其他模型
      @MODEL_REGISTRY.register()
      class MambaIRModel(nn.Module): ...
      
  1. from basicsr.utils.options import copy_opt_file

  2. from basicsr.utils.options import dict2str

  3. from basicsr.utils.options import parse_options

def parse_options(root_path, is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='../options/train/train_MambaIR_SR_x2.yml', help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local-rank', type=int, default=0) # for pytorch 2.0
    parser.add_argument('--local_rank', type=int, default=0) # for pytorch < 2.0
    parser.add_argument('--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # parse yml to dict
    with open(args.opt, mode='r') as f:
        opt = yaml.load(f, Loader=ordered_yaml()[0])

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()
    # 获取当前进程的rank和总进程数

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    # force to update yml options
    if args.force_yml is not None:
        for entry in args.force_yml:
            # now do not support creating new keys
            keys, value = entry.split('=')
            keys, value = keys.strip(), value.strip()
            value = _postprocess_yml_value(value)
            eval_str = 'opt'
            for key in keys.split(':'):
                eval_str += f'["{key}"]'
            eval_str += '=value'
            # using exec function
            exec(eval_str)

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']
      
    # 自动根据当前机器设置GPU数量。
    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']

    # paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
  • args:命令行参数对象
  1. 命令行参数解析(argparse)

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='../options/train/train_MambaIR_SR_x2.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local-rank', type=int, default=0)  # PyTorch 2.0+
    parser.add_argument('--local_rank', type=int, default=0)  # PyTorch <2.0
    parser.add_argument('--force_yml', nargs='+', default=None)
    args = parser.parse_args()
    
  • -opt:主配置文件路径

  • --launcher:分布式启动方式

    • none 使用场景:单机单卡/多卡(非分布式)。技术实现:直接运行脚本
    • pytorch 使用场景:单机多卡/多机多卡分布式训练。技术实现:使用torch.distributed.launchtorchrun
    • slurm 使用场景:多机多卡集群训练。技术实现:通过SLURM作业调度系统启动
    维度nonepytorchslurm
    底层技术单进程torch.distributedtorch.distributed + Slurm
    资源管理用户手动控制用户手动控制Slurm自动分配
    扩展性低(单机级)中(多机级)高(集群级)
    典型命令示例python train.pytorchrun --nproc_per_node=4 train.pysbatch job_script.slurm
  • --auto_resume:自动恢复标志。

    • 动作类型:store_true(当用户在命令行中指定该参数时,--auto_resume会被设为 True,否则为 False。)
    • 示例:python train.py --auto_resume
  • --debug:调试标志(同理)

  • --force_yml: 动态配置覆盖。

    • nargs:+表示接受1个或多个参数。
    • 通过exec动态执行赋值,如:opt["train"]["lr"] = 0.001
    • 优先级最高:--force_yml命令行参数(eg:--force_yml "train:lr=0.01")>命令行直接参数(eg:--launcher slurm)>YAML配置文件默认值
  • args

    • 包含了所有通过命令行传入的参数及其对应的值。

    • 未指定的参数使用默认值

    • 典型args对象:

      • Namespace(
         opt='configs/train.yml',
         launcher='pytorch',
         auto_resume=True,
         debug=False,
         local_rank=0,
         force_yml=None
        )
        
    • 示例1:当运行脚本时输入:python train.py --launcher pytorch 。得到的args对象为Namespace(launcher=pytorch)

    • 示例2:当运行脚本时输入:python train.py -opt my_config.yml --debug。得到的args对象为:
      Namespace(opt='my_config.yml', debug=True)

    • 属性访问方式:
      print(args.opt) # 输出: my_config.yml
      print(args.debug) # 输出: True

    • 为什么需要 args 而不用直接读YAML?:

      • 动态性:允许运行时修改配置(如临时调大学习率)
      • 灵活性:同一份YAML可衍生不同实验(通过不同命令行参数)
      • 可读性args.lrconfig['train']['lr'] 更直观
    • 如何查看所有可用参数?:python train.py --help # 自动生成帮助信息

    • yaml配置文件的关系:

    维度YAML配置文件命令行args
    存储形式结构化文本文件(层次化键值对)扁平化命名空间对象
    主要用途定义默认训练参数(模型结构、数据集路径等)运行时临时覆盖配置/开关控制
    修改频率低频(不同实验间修改)高频(单实验多次调试时修改)
    适用场景需要版本控制的固定参数临时性调试参数

    优先级:命令行参数>YAML配置

  • debug模式:

    • 命令行直接启动:python train.py --debug # 添加–debug参数即可

    • 组合其他参数:(调试模式+修改批量大小+提高日志频率)

      python train.py --debug --force_yml "datasets:train:batch_size_per_gpu=2" "logger:print_freq=1"

  • 其它训练方式:

    • 基础训练python train.py -opt options/train.yml
    • 分布式训练(2GPU):
    torchrun --nproc_per_node=2 train.py \
     --launcher pytorch \
     --opt options/train.yml \
     --auto_resume
    
    • python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 basicsr/train.py -opt options/train/mambairv2/train_MambaIRv2_SR_x2.yml --launcher pytorch
      
with open(args.opt, mode='r') as f:
    opt = yaml.load(f, Loader=ordered_yaml()[0])
  • 使用 yaml.load 读取yaml配置文件,ordered_yaml() 确保字段顺序保留

  • ordered_yaml()[0]
    自定义的加载器,用于保持 YAML 中字段的原始顺序(普通 yaml.load 可能返回无序字典)

  • 为什么需要保持顺序:打印或日志输出时,有序字典更易于定位参数。

  • 典型 YAML 配置文件示例

    # options/train.yml
    name: "MambaIR_SR_x2"  # 实验名称
    model:
      type: "MambaIR"      # 模型类型
      scale: 2             # 超分辨率缩放因子
    datasets:
      train:
        dataroot_gt: "data/DIV2K/train"  # 训练集路径
        batch_size: 16
    
  • 解析后 opt 的结构:

    {
        'name': 'MambaIR_SR_x2',
        'model': {'type': 'MambaIR', 'scale': 2},
        'datasets': {'train': {'dataroot_gt': 'data/DIV2K/train', 'batch_size': 16}}
    }
    

opt['rank'], opt['world_size'] = get_dist_info()

  • 获取当前进程的rank和总进程数
  • Rank是一个从 0 开始的整数,用于标识分布式训练中每个进程的唯一身份,决定了数据分配、通信顺序和任务分工。
  • rank的作用:
    • 数据分配:决定当前进程处理哪部分数据(如DDP中,rank=0的进程可能负责验证日志记录)
    • 梯度同步
    • 模型保存:通常仅由rank=0的进程保存模型,避免重复存储
    • 日志控制:仅rank=0进程打印日志,避免多进程重复输出
  • rank的取值规则:
    • 单机多卡(1台机器,N张GPU):rank ∈ [0, N-1]
    • 多机多卡(M台机器,每台N张GPU):全局rank范围:[0, M×N-1]。每台机器的local_rank(即单机内的GPU编号)范围:[0, N-1]
if args.force_yml is not None:
    for entry in args.force_yml:
        keys, value = entry.split('=')
        eval_str = 'opt' + ''.join(f'["{k}"]' for k in keys.split(':')) + '=value'
        exec(eval_str)  # 执行动态赋值
  • 示例--force_yml "train:lr=0.001" 会修改 opt['train']['lr'] 的值。
parser = argparse.ArgumentParser()  # 创建参数解析器
args = parser.parse_args()         # 解析命令行参数
  • 从命令行(如 python train.py --lr 0.01 --batch_size 32)中读取用户输入的参数,并将其转换为 Python 对象供程序使用。
if args.debug and not opt['name'].startswith('debug'):
    opt['name'] = 'debug_' + opt['name']  # 自动添加debug前缀# 修改实验名称
    #如name: "MambaIR_SR_x2"--->name: "debug_MambaIR_SR_x2"
    # 调试模式下调整验证频率和日志频率
    if 'val' in opt:
        opt['val']['val_freq'] = 8
    opt['logger']['print_freq'] = 1
  • 用于 调试模式(Debug Mode)的自动化配置,当用户通过命令行参数 --debug 启用调试模式时,它会自动修改实验配置以方便开发调试
for phase, dataset in opt['datasets'].items():
    phase = phase.split('_')[0]  # 处理多数据集分支(如val_1, val_2)
    dataset['phase'] = phase
    if 'scale' in opt:
        dataset['scale'] = opt['scale']  # 统一超分辨率缩放因子
  • 这里的phase是什么?
    opt['datasets']是一个字典,.items()来获取键值对。phase就是键key

  • 作用:
    datasets中可能长这样:

    datasets:
      train: 
        dataroot_gt: 'data/DIV2K/train'
      val_1:  
        dataroot_gt: 'data/DIV2K/val'
      val_2: 
        dataroot_gt: 'data/Set5'
      test_1: 
      dataroot_gt: 'data/Set5'
    

    for phase, dataset in opt['datasets'].items()得到的phasetrain,val_1,val_2,test_1。dataset得到的是dataroot_gt
    那么经过处理后得到的:

    datasets:
      train: 
        dataroot_gt: 'data/DIV2K/train'
        	phase:train
      val_1:  
        dataroot_gt: 'data/DIV2K/val'
        	phase:val
      val_2: 
        dataroot_gt: 'data/Set5'
        	phase:val
      test_1: 
      	dataroot_gt: 'data/Set5'
    		phase:test
    
    

    从而实现区分训练/验证/测试的数据,便于后续环节的进行。

opt&args&-opt

维度opt (配置字典)args (命令行对象)-opt (命令行参数)
数据类型嵌套字典命名空间对象(flat)字符串(文件路径)
来源YAML配置文件命令行输入命令行输入(-opt参数)
优先级基础配置可覆盖opt仅决定opt的加载路径
典型用途定义模型结构、超参数等临时调整实验行为指定配置文件路径

opt(配置字典)

  • 来源:从 YAML配置文件 加载(如 train_MambaIR_SR_x2.yml),通过 yaml.load() 解析为Python字典

  • 内容:包含模型、数据集、训练超参数等所有静态配置

  • 特点

    • 层级化结构(嵌套字典),适合管理复杂配置。
    • 通常与代码解耦,便于实验复现。
  • 示例

    # options/train.yml
    name: "exp1"
    model:
      type: "MambaIR"
      scale: 2
    

    解析后:

    opt = {
        'name': 'exp1',
        'model': {'type': 'MambaIR', 'scale': 2}
    }
    

args(命令行参数对象)

  • 来源:通过 argparse.ArgumentParser() 解析的命令行参数

  • 内容:用户运行时临时指定的动态覆盖选项

  • 特点

    • 扁平结构(单层属性),适合快速调整关键参数。
    • 优先级高于 opt(可覆盖YAML配置)。
  • 示例

    python train.py --launcher pytorch --debug
    

    解析后:

    args.launcher = "pytorch"
    args.debug = True
    

-opt(命令行参数中的配置路径)

  • 来源args 的一个特殊字段,通过 parser.add_argument('-opt', ...) 定义。

  • 作用:指定YAML配置文件的路径,是连接 argsopt 的桥梁。

  • 示例

    python train.py -opt options/train.yml  # 指向YAML文件
    

    代码中通过 args.opt 获取路径后加载配置:

    with open(args.opt) as f:
        opt = yaml.load(f)  # 生成opt字典
    

from basicsr.utils import

  1. AvgTimer

  2. MessageLogger

  3. check_resume

  4. get_env_info

  5. get_root_logger

  6. get_time_str

  7. init_tb_logger

  8. init_wandb_logger

  9. make_exp_dirs

  10. mkdir_and_rename

  11. scandir

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值