Pytorch 模型保存与加载总结

7 篇文章 0 订阅
3 篇文章 0 订阅

Pytorch 模型加载不匹配的问题

pytorch模型加载与保存

pytorch的模型的加载与保存,知乎文章写的清楚明白。也可以参考官方文档
总体来说就是两种:

  1. 直接保存模型总体。这样加载直接整个模型加载进来就好了。
torch.save(model,"yourpath.pth")# 一般保存为pth后缀
model = torch.load("yourpath.pth")
  1. 保存总体的问题就是太大了,而且不能更改。因此可以选择:只保存模型参数到OrderedDict(这个类没有仔细研究,不过见名知意,就是有顺序的字典)。加载时要先初始化模型,然后把这个参数导进去。因为是按照参数名称导入,只要参数名称一样就可以了。这样你就可以修改部分网络结构,而把没变的部分的参数导进去了。
torch.save(model.state_dict(),"yourpath.pth")
#load
model = ModelClass()##初始化你的网络
model_state_dict = torch.load("yourpath.pth")
model.load_state_dict(model_state_dict)

至于模型参数更新的具体步骤,就参考知乎的文章吧。

模型参数不匹配的问题

需要注意的是,在训练模型的时候,往往需要部署到GPU中:

model = nn.DataParallel(model)
model = model.cuda()

这样后,模型就不是原来模型的类型了,而是并行化后的模型:
Dataparallel
相应的参数字典也变了:
state_dict
所以如果并行化后的参数要加载到并行化的,没并行化的要加载到没并行化的。
那么并行化后,如何保存没并行化的结果呢1

torch.save(model.module.state_dict(), 'file_name.pt').

  1. pytorch:Missing keys & unexpected keys in state_dict when loading self trained model@ptrblck ↩︎

  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值