pytorch中模型导入问题

pytorch模型导入问题

1、RuntimeError: Error(s) in loading state_dict for DataParallel:


这里说明:训练模型的测试加载模型使用的环境不一样

解决方法:

1、在load_state()函数中加上False

model.load_state(checkpoint,False)

  从属性state_dic里复制到这个模块和他的后代,如果strict为True,state_dic的keys必须完全与这个模块的方法返回的keys相同,如果为False则不需要匹配。
  这个方法虽然可以将模型导入,但是不能保证导入的模型参数能完全匹配模型,在做测试的时候发现导入的参数只是随机的一部分,只能识别单一分类。
2、修改模型参数的keys使其与这个模型返回的keys一致

from collections import OrderedDict

checkpoint=torch.load('参数路径')
state_dict = checkpoint 
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.' + k
    else:
        continue
    new_state_dict[k]=v
model.load_state_dict(new_state_dict)

这个方法可以保证导入的模型参数跟模型完全一致。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值