python训练设置断点方便中断后继续训练

本文介绍了在模型训练过程中如何实现断点续训的方法。通过保存每个epoch后的模型参数,并在下一次训练开始时加载这些参数,可以有效地解决因意外中断而导致的训练进度丢失问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


前言

因为有些人是使用服务器去跑程序,不是用的自己的主机,所以不可避免的出现程序被其他使用者中断或者多人共用导致内存爆掉而程序被kill,这个时候在程序中添加断点就显得尤为重要。


提示:以下是本篇文章正文内容,下面案例可供参考

一、断点是什么

其实没有那么难理解,其实原理就是在程序运行时,每训练一个epoch就把模型训练后得到的各种参数保存下来,下次训练直接加载你保存各种参数的这个模型就可以继续接着训练了。

二、步骤

1.保存模型

找到你网络中保存权值的语句,我们就在这附近添加语句即可:

#保存权值语句
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
#保存中断时模型的语句,我设置的是每训练完一轮就保存一下
if (epoch + 1) % 1 == 0 or epoch + 1 == UnFreeze_Epoch:
                state = {'model':model.state_dict(),'optimizer':optimizer.state_dict(),'eppoch':epoch}
                torch.save(state,log_dir)

2.在训练前加载保存的模型

代码如下(示例):

test_flag = True
log_dir = '.../*.pth'
if test_flag:
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        Init_Epoch = checkpoint['eppoch']
        print('加载中断处模型成功,继续训练!')
else:
        Init_Epoch = 0
        print('无保存模型,从头开始训练! ')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值