pytorch保存模型的坑

本文介绍了PyTorch在保存和加载模型过程中遇到的设备错误和模型错误,包括如何处理模型默认在GPU上的问题,解决不同PyTorch版本之间的兼容性问题,以及如何正确保存和恢复优化器的状态。总结建议在保存时将模型和优化器参数转到CPU,并根据训练需求决定是否保存优化器状态。
摘要由CSDN通过智能技术生成

pytorch中保存模型相关的函数有3个:

  • torch.save:利用python的pickle模块实现序列化并保存序列化后的object
  • torch.load:利用pickle将保存的object反序列化
  • torch.nn.Module.load_state_dict:通过反序列化得到的state_dict读取保存的训练参数

有两种方法保存模型:

1. torch.save(model, path) # 直接保存整个模型
2. torch.save(model.state_dict(), path) # 保存模型的参数

相应地有两种方法加载保存的模型:

1. model = torch.load(path) # 直接加载模型
2. model = Model()                         # 先初始化一个模型
   model.load_state_dict(torch.load(path)) # 再加载模型参数

看起来第一种方法更加简单

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值