pytorch学习系列(10):断点续训

训练中难免遇到停电、宕机等情况。如果没有设置断点续训,那么就得从头开始,非常浪费时间和能源。
要实现断点续训,首先要保存之前训练的模型,除此之外,还有优化器的状态。这是因为目前很多优化器的参数是会随着训练而不断变化的,具体可见这篇博客:pytorch学习系列(4):常用优化算法
首先设置两个参数:
start_epoch:续训时开始训练的epoch数
resume:是否实施断点续训

start_epoch = 0
resume = True

模型和优化器设置完成后,加入如下代码,载入checkpoint再进行训练,

if opt.resume:
   if os.path.isfile('checkpoint'):
       checkpoint = torch.load('checkpoint')
       start_epoch = checkpoint['epoch'] + 1
       model.load_state_dict(checkpoint['model'])
       optimizer.load_state_dict(checkpoint['optimizer'])
       print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
   else:
       print("=> no checkpoint found")

每一轮训练保存模型后,加上如下代码保存epoch、model和optimizer的信息到checkpoint文件中。

checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint,'checkpoint')

这种方式也支持tensorboardX可视化的断点续图,只不过会多产生一个日志文件,只记录了断点后的情况。原来的日志文件中还会继续记录断点后的数据。

  • 8
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值