size mismatch for rpn.head.cls_logits.weight

问题描述

对mobileNetv2+faster_RCNN进行训练之后,将权重保存。但是在预测时,导入权重出现错误。

RuntimeError: Error(s) in loading state_dict for rpn:
	size mismatch for cls_logits.weight: copying a param with shape torch.Size([256,256,1,1]) from checkpoint, the shape in current model is torch.Size([3,256,1,1]).
	size mismatch for cls_logits.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([3]).
	size mismatch for bbox_pred.weight: copying a param with shape torch.Size([1280,1280,1,1]) from checkpoint, the shape in current model is torch.Size([12,1280,1,1]).
	size mismatch for bbox_pred.bias: copying a param with shape torch.Size([1280]) from checkpoint, the shape in current model is torch.Size([11]).

解决办法

根据报错信息,报错的层为 rpn,
在这里插入图片描述
针对这个问题,就将

weights_path = "./save_weights/mobile-model-14.pth"
    assert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)
    weights_dict = torch.load(weights_path, map_location='cpu')
    # print(len(weights_dict))
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
	
    model_dict = model.state_dict()
    model_dict = model.state_dict()
    # 解决报错问题,将rpn部分的参数冻结,不进行赋值。
    pretrained_dict = {k: v for k, v in weights_dict.items() if (k in model_dict and 'rpn' not in k)}
    model.load_state_dict(weights_dict,False)

    model.to(device)

注:

但笔者在后续的过程,将该行代码注释后运行,竟然没有报错。反倒是进行冻结后,预测准确率低,时常抽风。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值