Pytorch模型的保存及加载

深度学习模型保存模型参数的方法有两种:
1.保存整个网络(模型结构+模型参数):
# 保存整个模型和参数
torch.save(model_object, 'convit_tiny.pth')  
    
# 对应的加载模型代码为
model = torch.load('convit_tiny.pth')
print(model)

此时print的是整个网络的模型结构;
在这里插入图片描述
若要加载模型的参数:

model = torch.load('convit_tiny.pth')
args = model.state_dict()
print(args)

此时输出的是模型的训练参数:
在这里插入图片描述

2.直接保存网络的模型参数:
# 将my_resnet模型储存为my_resnet.pth,此时保存的仅仅是模型的参数
torch.save(model.state_dict(), "convit.pth")
# 直接加载参数
args = torch.load("convit.pth")
# 若要加载模型则先需要初始化之前所定义的网络
new_model = Net()
# 再使用load_state_dict方法将权重加载进网络
# 注意:model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数;而这里是导入参数因此用的是model.load_state_dict()而不是model.state_dict()
new_model.load_state_dict(torch.load('convit.pth'))
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值