中断网络训练后,如何根据保存的权重文件继续训练

在对网络进行训练的时候,尤其是一些结构复杂的大型网络模型,一般需要训练很长时间才能得到比较满意的结果,要是从零开始训练那就是更漫长的等待了!在训练的过程中总可能发生一些预想不到的意外从而造成训练的终止,所以学会如何在网络训练中断后,根据保存的权重文件继续训练就显得至关重要!!!!!

下面就来具体讲讲如何断点续训

一、保存模型

可以看我的这篇博客,看完后应该会明白的,如果有不明白的欢迎评论区留言,互相交流!

二、加载模型

使用 torch.load函数 进行加载

model.load_state_dict(torch.load('权重文件所在的路径', map_location='要存放的设备'))

三、保存权重文件

一般地,权重文件中需要保存:网络模型权重、优化器权重、以及epoch,便于继续训练

我一般会设置条件保存权重文件,权重文件相对占用的空间比较大,可以设置 每间隔几个epoch 保存一次

if epoch % 5 == 0:
    checkpoint = {
         "model":model.state_dict(),
         "optimizer":optimizer.state_dict(),
         "epoch":epoch
      }         # 要保存的权重文件内容

    # 创建保存权重文件所在的文件夹
    if not os.path.isdir('./checkpoint'):
           os.makedirs("./checkpoint")
    # 使用 torch.save() 函数保存权重文件到指定的路径下
    torch.save(checkpoint, './checkpoint/ckpt_%d.pth' % int(epoch))
    # 或者写成:torch.save(checkpoint, './checkpoint/ckpt_{}.pth '.format(int(epoch)))
    # torch.save(checkpoint, f'./checkpoint/ckpt_{int(epoch)}.pth')
这样就可以在checkpoint文件夹下看到每间隔5个epoch保存的权重文件

四、在断点处继续训练

注意 start_epoch 的设定,以保证再次训练时,epoch的次数匹配

start_epoch = 0      # strat_epoch的值一般默认为0,表示从头开始进行训练
# Resume = False     # 用于指示是否继续训练   取值为逻辑值: False:初次训练; True:继续训练
Resume = True
# 判断是否继续训练 
if Resume:
    path_checkpoint = "./checkpoint/ckpt_5.pth"  # 断点权重文件的所在路径
    checkpoint = torch.load(path_checkpoint)  # 加载保存的断点文件

    model.load_state_dict(checkpoint['model'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']                   # 设置继续上次训练的epoch     

最后,感谢博主

大梦冲冲冲icon-default.png?t=N7T8https://blog.csdn.net/qq_37844044

的博文:

pytroch网络训练中断后,根据断点再次训练!!!_pytorch训练中断后怎么恢复-CSDN博客

给予我灵感 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值