pytorch note

1.模型保存与加载

1.1

#a、保存 推荐仅仅保存模型的state_dict
torch.save(model.state_dict(), MODELPATH) # .pt  .pth
#b、加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
#Pytorch保存的模型后缀一般是.pt或者.pth
#必须在加载模型后调用model.eval函数来将dropout及批归一化层设置为预测模式。如果不这么做结果出错。

1.2 a、保存临时模型用于预测或再训练

torch.save({
 'epoch': epoch,
 'model_state_dict': model.state_dict(),
 'optimizer_state_dict': optimizer.state_dict(),
 'loss': loss, ... },
 PATH)

当保存一个临时模型用于预测或再训练时,需要保存比state_dict更多的参数。包括优化器的state_dict,迭代次数epoch,最后一层迭代的loss及其他任何需要的参数。
 当保存多个组件时,将多个组件以字典的形式组织,然后用torch.savee()来序列化该字典。在Pytorch中常用.tar文件后缀表示这种模型。
b、加载

model = TheModelClass(*args, **kwargs) 
optimizer = TheOptimizerClass(*args, **kwargs)
 checkpoint = torch.load(PATH)
 model.load_state_dict(checkpoint['model_state_dict'])
 optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 epoch = checkpoint['epoch']
 loss = checkpoint['loss']
 model.eval() #预测 # - or - model.train() #再训练

e.g.

save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer' : optimizer.state_dict(),
            }, checkpoint=args.checkpoint)

转载于:https://www.cnblogs.com/yanghailin/p/11607080.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值