pytorch训练中断后,如何在之前的断点处继续训练

本文讲述了如何在模型训练过程中遇到意外中断时,通过定期保存模型权重并加载断点继续训练的方法。介绍了代码实现,包括模型的保存(如torch.save)和加载(torch.load)步骤,确保训练的连续性和效率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我们在训练模型的时候经常出现各种问题导致训练中断,比方说断电,或者关机之类的导致电脑系统关闭,从而将模型训练中断,那么如何在模型中断后,能够保留之前的训练结果不被丢失,同时又可以继续之前的断点处继续训练?

首先在代码离需要保存模型,比方说我们模型设置训练5000轮,那么我们可以选择每100轮保存一次模型,这样的话,在训练的过程中就能保存下100,200,300.。。。等轮数时候的模型,那么当模型训练到400轮的时候突然训练中断,那么我们就可以通过加载400轮的参数来进行继续训练,其实这个过程就类似在预训练模型的基础上进行训练。下面简单粗暴上代码:

1、保存模型

torch.save(checkpoint, checkpoint_path)

其中checkpoint其实保存的就是模型的一些参数,比方说下面这种字典形式的保存所需的模型参数:

checkpoint = {
    'model': model_state_dict,
    'generator': generator_state_dict,
    'opt': model_opt,
    'optim': optim,
}

checkpoint_path则是表示保存的模型

checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)

save_checkpoint_steps是保存的间隔轮数,step是保存的轮数,比方说save_checkpoint_steps=100,那么step的取值就是100,200,300,400等,下面的代码解释step的取值由来。

if step % self.save_checkpoint_steps != 0:
    return
chkpt, chkpt_name = self._save(step)

其中_save函数就是实现了前面checkpoint的内容的保存。

模型的保存设置就此结束。

2、模型的加载

假如此时模型训练中断了,我们得在代码里设置一个参数,这个参数用来查找确定当前路径下是否有已存在得模型。

# 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        start_epoch = checkpoint['model_opt']
        optim=checkpoint['optim']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

或者设置一个变量train_from,若赋值已有模型得路径,则继续训练;若为None,那么从头训练。这块代码既可以用于训练中断,又可以用于使用预训练模型。

if opt.train_from:#是否存在预训练模型
    logger.info('Loading checkpoint from %s' % opt.train_from)
    checkpoint = torch.load(opt.train_from)#加载预训练模型的检查点
    model_opt = checkpoint['opt']
else:
    checkpoint = None
    model_opt = opt

加油,come on!

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序小K

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值