RuntimeError: Error(s) in loading state_dict for AutoEncoderModel......

 RuntimeError: Error(s) in loading state_dict for AutoEncoderModel:         Missing key(s) in state_dict: "net.pred.0.weight", "net.pred.0.bias", "net.pred.2.weight", "net.pred.2.bias", "net.pred.4.weight", "net.pred.4.bias".          Unexpected key(s) in state_dict: "net.flow.cnn.0.weight", "net.flow.cnn.0.bias", "net.flow.cnn.2.weight", "net.flow.cnn.2.bias", "net.flow.cnn.4.weight", "net.flow.cnn.4.bias".

        在训练时可能会出现以上错误:原因是由于加载错误的模型参数导致的。

        一般而言,我们在测试时,使用的模型model,要加载相应的模型参数信息,即checkpoint。

        假设在训练时定义了两个模型:model1和model2。训练完model1之后保存了模型参数,命名为checkpoint_1;然后在model_1的基础上修改了模型,在内部添加了一些新的参数(或者更改了参数的名称),命名为model_2,训练之后保存了模型参数,命名为checkpoint_2。

        那么在训练时:

  1.  如果使用的模型是model_1,而加载的模型参数是checkpoint_2,这时可能会出现Unexpected key(s) in state_dict: ...这类错误。因为checkpoint_2内保存的模型参数是model_2的参数,而model_2是在model_1的基础上添加了一些参数,我们把模型参数checkpoint_2加载到模型model_1上,就会多出一些不能匹配的参数。
  2. 如果使用的模型是model_2,而加载的模型参数是checkpoint_1,这时可能会出现Missing key(s) in state_dict: ...这类错误。因为checkpoint_1内保存的模型参数是model_1的参数,而model_1要比model_2内部少一些参数,我们把模型参数checkpoint_1加载到模型model_2上,就会出现缺少参数这类错误。
  3. 如果是更改了model中参数的名称,可能会同时出现上述两种错误。

 


pytorch 预训练model加载问题”Unexpected key(s) in state_dict:“

预训练模型使用:

  model = ResNet50()
  model = nn.DataParallel(model).cuda()

加载预训练模型是使用:


model = ResNet50().cuda()
 
checkpoint = torch.load(args.net_cache)
 
model.load_state_dict(checkpoint['state_dict'])//


报错:RuntimeError: Error(s) in loading state_dict for Resnet50   Unexpected key(s) in state_dict:““

说明加载模型和预训练模型环境不一致,本来修改如下(后经权重输出对比发现这种方式加载的model与原model权重不一致):

model = ResNet50().cuda()
 
checkpoint = torch.load(args.net_cache)
 
model.load_state_dict(checkpoint['state_dict'],False)//
model.load_state_dict(state_dict, strict=True)

解决1:由于用DataParallel训练的模型数据并行方式的,key中会包含”module“关键字,加载时直接用:

model = ResNet50().cuda()
model = nn.DataParallel(model)
 
checkpoint = torch.load(args.net_cache)
 
model.load_state_dict(checkpoint['state_dict'])//


解决2: 去掉DataParallel 预训练model中的module(可行权重值一致):

model = ResNet50().cuda()
checkpoint = torch.load(args.net_cache)
model.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()})


 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值