Pytorch:epoch保存和断点训练

        为什么要保存epoch?如果模型比较大,在训练时可能会由于某些意外原因导致训练还没有完全完成就终止,对付这种情况,可以通过每隔一定数量的epoch就保存一次模型参数,下次如果出现训练终止的情况时,就可以加载最新的模型来恢复训练,而不用从头开始。

        在首次训练时,先对模型进行保存,这里根据自己代码的实际情况进行划分,笔者以5个epoch保存为一个checkpoint,在保存时,一定要记得保存optimizer,否则,最终的loss结果会出现与预期不相符:

for epoch in range(start_epoch, epoches):
	train(train_dataloader, model, loss_func, optimizer, epoch)
	test(test_dataloader, model, loss_func)
	if (epoch + 1) % 5 == 0:
		print('epoch:', epoch + 1)
		print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
		checkpoint = {
			'model': model.state_dict(),
			'optimizer': optimizer.state_dict(),
			'epoch': epoch
		}
		torch.save(checkpoint, './checkpoint/ckpt_best_%s.pth' % (str(epoch + 1)))
		print('Saved all parameters!\n')

        保存的模型参数:

   

        从上次的断点处继续训练,笔者在这里列出两种方法,第一种是通过判断RESUME来确定是否恢复训练;第二种是直接加载上次的断点路径来继续训练。

第一种:

RESUME = True  #控制是否是恢复训练。False:初次训练;True:继续训练
epoches = 20  #训练了20次
start_epoch = -1
#模型断点的设置
if RESUME:
	print('-----------------------------')
	path_checkpoint = "./checkpoint/ckpt_best_5.pth"  # 断点路径
	checkpoint = torch.load(path_checkpoint)  # 加载断点
	model.load_state_dict(checkpoint['model'])  # 加载模型可学习参数
	optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
	start_epoch = checkpoint['epoch'] + 1  # 设置开始的epoch
	print('加载 epoch {} 成功!'.format(start_epoch))
	print('-----------------------------')
else:
	start_epoch = 0
	print('无保存模型,将从头开始训练!')

第二种:

if os.path.exists("./checkpoint/ckpt_best_10.pth"):
	checkpoint = torch.load("./checkpoint/ckpt_best_10.pth")
	model.load_state_dict(checkpoint['model'])
	optimizer.load_state_dict(checkpoint['optimizer'])
	start_epoch = checkpoint['epoch'] + 1
	print('加载 epoch {} 成功!'.format(start_epoch))
else:
	start_epoch = 0
	print('无保存模型,将从头开始训练!')

 

  • 0
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值