pytorch保存和加载模型 & checkpoint断点续传

pytorch保存和加载模型

保存模型
有两种方式保存模型
一、保存整个网络
保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net。 后缀一般命名为.pkl

net = Net()
 
# 保存和加载整个模型
torch.save(net, 'model.pkl')
model = torch.load('model.pkl')

二、仅保存模型参数
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。只保存神经网络的训练模型参数,save的对象是net.state_dict()。后缀一般命名为 .pt 或者 .pth

# 仅保存和加载模型参数(推荐使用)
torch.save(net.state_dict(), 'params.pth')
net.load_state_dict(torch.load('params.pth'))

以上加载的模型在进行预测前,要调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。

但是如果在train的过程中保存,即你训练还没完不得已先结束或者断电断网了之类的,那就需要加载和保存一个通用的检查点(Checkpoint),也就是断点续传。
checkpoint断点续传
当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法

#保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
 
torch.save(checkpoint, 'checkpoints/ckpt_%s.pth' %(str(epoch)))
#加载
checkpoint = torch.load('checkpoints/ckpt_100.pth')  # 加载断点
 
model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
start_epoch = checkpoint['epoch']  # 设置开始的epoch
loss = checkpoint['loss']

#续算
#加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。
model.eval()
# - 或者 -
model.train()
#加载之后再训练就直接从start_epoch+1开始了

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值