关于Resume训练 精度对齐的思考

结论:如果想要精度对齐,需要Resume 正确的 state_dict,正确的学习率调度器,正确的恢复优化器的状态,同时还需要设置相同的随机种子。

讨论1、如果同时保存了state_dict_ema与state_dict,应该加载哪个?

结论1:加载state_dict_ema

什么是 state_dict_ema

state_dict_ema 是使用指数移动平均计算的模型权重。EMA 是一种平滑技术,常用于减小训练中的噪声和提高模型的泛化能力。EMA 权重通常比原始训练权重更加平滑,但在某些情况下,EMA 权重可能与标准权重有显著差异,尤其是在恢复训练时。为了确保训练精度的一致性,应该明确何时使用 state_dictstate_dict_ema。通常在恢复训练时,应使用标准的 state_dict 来恢复模型参数,除非你明确知道要继续使用 EMA 权重进行训练。

使用 EMA 的注意事项

  1. 恢复训练时:通常使用标准的 state_dict,因为这是在训练过程中直接更新的权重。
  2. 验证和测试时:可以使用 state_dict_ema 来评估模型性能,因为 EMA 权重可能具有更好的泛化能力。

讨论2、学习率调度器 

学习率调度器(Learning Rate Scheduler)是一种用于动态调整模型训练过程中学习率的机制。在训练深度学习模型时,学习率是一个非常重要的超参数,它控制着每次权重更新的步长。选择合适的学习率对模型的收敛速度和最终性能至关重要。学习率调度器可以根据预定的策略在训练过程中调整学习率,从而帮助模型更快更好地收敛。

为什么需要学习率调度器?

  1. 提高收敛速度:在训练开始时使用较大的学习率可以使模型快速接近最优解。
  2. 避免陷入局部最优:逐渐减小学习率可以帮助模型跳出局部最优解,找到全局最优解。
  3. 平滑收敛:在训练后期使用较小的学习率可以使模型的权重更新更精细,从而提高最终的模型性能。

常见的学习率调度策略

  1. 固定学习率(Constant Learning Rate):学习率在整个训练过程中保持不变。
  2. 阶梯下降(Step Decay):每隔固定的训练步数,学习率按固定比例衰减。
  3. 指数衰减(Exponential Decay):学习率按指数函数衰减。
  4. 余弦退火(Cosine Annealing):学习率按照余弦函数规律衰减。
  5. 自适应学习率(Adaptive Learning Rate):如 Cyclical Learning Rate (CLR) 或 Learning Rate Range Test (LRRT),根据模型的训练状态动态调整学习率。

讨论3、优化器状态 

在深度学习中,优化器的状态(optimizer state)是指优化器在训练过程中维护的一些内部变量和信息,这些变量和信息用于控制模型参数的更新过程。优化器状态通常包括动量、学习率调度器的状态、梯度的历史信息等。恢复训练时,除了恢复模型的参数,还需要恢复优化器的状态,以确保训练过程的连续性和一致性。

优化器状态的重要性

优化器状态在训练过程中非常重要,因为它包含了优化器的内部信息,这些信息决定了模型参数如何更新。忽略优化器状态会导致以下问题:

  1. 训练不稳定:如果不恢复动量等状态信息,训练过程可能变得不稳定,导致模型收敛速度变慢或无法收敛。
  2. 性能下降:学习率调度器的状态不正确会导致学习率调整不当,进而影响模型性能。
  3. 不连续的训练过程:恢复训练时,如果优化器状态不正确,会导致训练过程的不连续,影响最终的训练结果。

优化器状态包含什么

具体包含的内容取决于使用的优化器类型。以下是一些常见优化器的状态信息:

  1. SGD(随机梯度下降)优化器

    • 动量(momentum)
    • 上一次更新的梯度值
  2. Adam 优化器

    • 一阶矩估计(moving average of gradients)
    • 二阶矩估计(moving average of squared gradients)
    • 时间步(time step)

讨论4、随机种子

在计算机科学和机器学习中,随机种子(random seed)是用于初始化随机数生成器的一个固定值。设置随机种子可以确保随机数生成器在每次运行时产生相同的随机数序列,从而使得实验和结果可重复。这对于调试、结果对比和实验的再现性非常重要。

为什么使用随机种子?

  1. 结果可重复性:在训练机器学习模型时,很多过程(如数据打乱、权重初始化等)涉及随机性。设置随机种子可以确保这些过程每次运行时都一致,从而得到可重复的结果。
  2. 调试方便:在调试模型时,固定随机种子可以帮助你定位问题,因为每次运行时的随机数序列相同,问题的表现也会一致。
  3. 结果对比:在对比不同模型或算法的效果时,使用相同的随机种子可以确保对比的公平性,因为每个实验的初始条件一致。
import torch
import random
import numpy as np

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 设置随机种子
set_seed(42)

实战

 基于以上的讨论,在保存checkpoint时,也要加入对应信息。

def save_checkpoint(work_dir,
                    epoch,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    keep_last=False,
                    step=None,
                    ):
    os.makedirs(work_dir, exist_ok=True)
    state_dict = dict(state_dict=model.state_dict())
    if model_ema is not None:
        state_dict['state_dict_ema'] = model_ema.state_dict()
    if optimizer is not None:
        state_dict['optimizer'] = optimizer.state_dict()
    if lr_scheduler is not None:
        state_dict['scheduler'] = lr_scheduler.state_dict()
    if epoch is not None:
        state_dict['epoch'] = epoch
        file_path = os.path.join(work_dir, f"epoch_{epoch}.pth")
        if step is not None:
            file_path = file_path.split('.pth')[0] + f"_step_{step}.pth"
    logger = get_root_logger()
    torch.save(state_dict, file_path)
    logger.info(f'Saved checkpoint of epoch {epoch} to {file_path.format(epoch)}.')
    if keep_last:
        for i in range(epoch):
            previous_ckgt = file_path.format(i)
            if os.path.exists(previous_ckgt):
                os.remove(previous_ckgt)

保存调用

                save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
                                epoch=epoch,
                                step=(epoch - 1) * len(train_dataloader) + step + 1,
                                model=accelerator.unwrap_model(model),
                                model_ema=accelerator.unwrap_model(model_ema),
                                optimizer=optimizer,
                                lr_scheduler=lr_scheduler
                                )

resume加载时

def load_checkpoint_net(checkpoint,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    load_ema=False,
                    resume_optimizer=True,
                    resume_lr_scheduler=True
                    ):
    assert isinstance(checkpoint, str)
    ckpt_file = checkpoint
    checkpoint = torch.load(ckpt_file, map_location="cpu")
    if load_ema:
        state_dict = checkpoint['state_dict_ema']
    else:
        state_dict = checkpoint.get('state_dict', checkpoint)  # to be compatible with the official checkpoint
    # model.load_state_dict(state_dict)
    missing, unexpect = model.load_state_dict(state_dict, strict=False)
    if model_ema is not None:
        # 获取 model_ema 的状态字典
        model_ema_state_dict = model_ema.state_dict()
        # 过滤掉在 model_ema 中不存在的键
        pretrained_state_dict = {k: v for k, v in checkpoint['state_dict_ema'].items() if k in model_ema_state_dict}
        # 更新 model_ema 的状态字典
        model_ema_state_dict.update(pretrained_state_dict)
        model_ema.load_state_dict(model_ema_state_dict, strict=False)
    if optimizer is not None and resume_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None and resume_lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['scheduler'])
    logger = get_root_logger()
    if optimizer is not None:
        epoch = checkpoint.get('epoch', re.match(r'.*epoch_(\d*).*.pth', ckpt_file).group()[0])
        logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
                    f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
        return epoch, missing, unexpect
    logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
    return missing, unexpect

resume调用  一定要让load_ema=False

    if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
        start_epoch, missing, unexpected = load_checkpoint_net(**config.resume_from,
                                                           model=model,
                                                           model_ema=model_ema,
                                                           optimizer=optimizer,
                                                           lr_scheduler=lr_scheduler,
                                                           )

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

开始学AI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值