torch保存和加载 模型、参数

保存整个模型,包括模型结构+参数

torch.save(self.model.cpu(), output_path)

相应的加载是

model = torch.load(output_path)

缺点:这种方式保存的模型只能在相同的环境中使用,因为它依赖于模型定义的代码。如果您在不同的环境中使用这种方式保存的模型,可能会出现错误。(相同的环境指的是运行代码的环境,包括操作系统、Python 版本、PyTorch 版本以及其他依赖库的版本都相同。)

保存模型的参数

torch.save(self.model.state_dict(), output_path)

相应的加载是

state_dict = torch.load(save_path)
bert_model = Architecture() # 实例化模型类
bert_model.load_state_dict(state_dict)

优点:想要在不同的环境中使用保存的模型,建议使用 state_dict 来保存和加载模型的参数。这样,您只需要在新环境中定义相同的模型结构,然后使用 load_state_dict 函数加载保存的参数即可。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值