[Pytorch] 模型的保存torch.save()与加载torch.load()

在模型训练验证之后,可以把训练好的模型进行保存
后续进行测试的时候,可以直接加载模型进行测试

保存模型

将模型保存下来,如果后续有数据了可以继续训练,或者直接加载进行测试。

  1. 只保存模型的参数(一般只保存参数即可)
torch.save(model.state_dict(), PATH)  # 保存模型参数
  1. 保存整个模型(较大)
torch.save(model,PATH)  # 保存整个模型

加载模型

  1. 只加载模型的参数(先创建模型、并加载参数,再恢复得到模型)
model = MyModel().to(device)
checkpoint = torch.load(config['save_path'])  # 先加载参数
model.load_state_dict(checkpoint)  # 再让模型加载参数, 恢复得到模型
  1. 加载整个模型
model = torch.load(PATH)

在torch.load()中,常常会使用map_location进行cpu和gpu的转化。
(1)【GPU->CPU】
比如模型训练的时候是在GPU上进行并保存的,测试的时候却想在CPU上进行训练:

model = torch.load(PATH, map_location='cpu')

(2)【CPU->GPU】
转为GPU需要指明在哪块GPU上,例如转到第0块GPU:

model = torch.load(PATH, map_location=lambda storage, loc: storage.cuda(0))

(3)【GPU->GPU】
不同块GPU的转换,例如第1块转到第0块:

torch.load(PATH, map_location={'cuda:1':'cuda:0'})

Note:如果只保存了模型参数,就加载模型参数;如果保存了整个模型,就加载整个模型;上述两组一一对应。

参考:

  1. pytorch 状态字典:state_dict使用详解:https://blog.csdn.net/Bruce_0712/article/details/111990905
  2. torch.load_state_dict()函数的用法总结:https://blog.csdn.net/ChaoMartin/article/details/118686268
  3. pytorch cpu与gpu load时相互转化 torch.load(map_location=):https://blog.csdn.net/bc521bc/article/details/85623515
  • 4
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值