[深度学习笔记(3)]模型保存与加载


本系列是博主刚开始接触深度学习时写的一些笔记,写的很早了一直没有上传,趁着假期上传一下,作为分享,希望能帮助到你。

一、模型保存

保存模型/模型参数。

torch.save(obj, f, pickle_module = <module‘ ... ‘>,pickle_protocol=2)

        其中,obj是需要保存的对象,f是类文件对象或一个保存文件名的字符串,pickle_module指用于picking元数据和对象的模块,pickle_protocol指可以覆盖的默认参数。举例说明:

torch.save(model, ‘model.pt’)  #保存整个模型
Torch.save(model.state_dict(), ‘model.pt’)  #保存训练好的网络权重

二、模型加载

1.加载模型

torch.load(f, map_location=None,pickle_module = <module‘ pickle ‘ from ‘ ... ’>)

        其中,f是类文件对象或一个保存文件名的字符串,map_location指一个函数或字典规定如何映射存储设备,pickle_module指用于unpicking元数据和对象的模块(必须匹配序列化文件时的pickle_module)。

2.加载模型参数

torch.nn.Module.load_state_dict(state_dict, strict=True)

        其中,state_dict指保存parameters和persistent buffers的字典。只有包含了可学习参数的层(如卷积层、线性层等)和已注册的命令才有模型的state_dict入口。

举例说明:

#(1)
#保存整个模型
torch.save(model_object, ‘model.pth’)
#加载模型
model = torch.load(‘model.pth’)

#(2)
#保存参数
torch.save(model_object.state_dict(), ‘params.pth’)
#加载模型
model_object=model()
model_object.load_state_dict(torch.load(‘params.pth))

#(2)的模型效果非常差,解决方法:
#(2)plus:
#保存参数
torch.save(model_object.state_dict(), ‘params.pth’)
#加载模型
model_object=model()
model_object.load_state_dict(torch.load(‘params.pth))
model.eval() #固定dropout()和BN层

        其中,model.eval()的作用是固定dropout()和BN层。


总结

        以上就是今天要讲的内容,本文介绍了模型保存与加载的详细代码实现,希望能够帮助到你。如有错误,请及时指出,我们一起进步!

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值