Problem description
After training mobileNetv2+faster_RCNN, I saved the weights. When loading them for prediction, however, the following error occurred:
RuntimeError: Error(s) in loading state_dict for rpn:
size mismatch for cls_logits.weight: copying a param with shape torch.Size([256,256,1,1]) from checkpoint, the shape in current model is torch.Size([3,256,1,1]).
size mismatch for cls_logits.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for bbox_pred.weight: copying a param with shape torch.Size([1280,1280,1,1]) from checkpoint, the shape in current model is torch.Size([12,1280,1,1]).
size mismatch for bbox_pred.bias: copying a param with shape torch.Size([1280]) from checkpoint, the shape in current model is torch.Size([11]).
Solution
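According to the error message, the mismatched parameters all belong to the rpn layers.
To work around this, filter the rpn parameters out of the checkpoint before loading it: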
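import os
import torch

# model and device are assumed to be created earlier in the script.
weights_path = "./save_weights/mobile-model-14.pth"
assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
weights_dict = torch.load(weights_path, map_location='cpu')
# print(len(weights_dict))
# The checkpoint may wrap the weights under a "model" key.
weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
model_dict = model.state_dict()
# Work around the error: skip the rpn parameters so they are not assigned from the checkpoint.
pretrained_dict = {k: v for k, v in weights_dict.items() if (k in model_dict and 'rpn' not in k)}
# Load the filtered dict; strict=False tolerates the rpn keys that are now missing.
model.load_state_dict(pretrained_dict, strict=False)
model.to(device)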
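Filtering on the substring 'rpn' drops every rpn parameter, including any whose shapes do match. A more conservative option is to compare shapes key by key and skip only the entries that cannot be copied. The sketch below is one way to do that; `split_by_shape` is a hypothetical helper name, and it assumes only that each value exposes `.shape`, which torch tensors do.

```python
def split_by_shape(checkpoint, model_state):
    """Split checkpoint entries into those that fit the model and those that do not.

    Both arguments map parameter names to objects exposing `.shape`.
    Returns (compatible, skipped), where `skipped` maps each rejected
    name to a short reason string.
    """
    compatible, skipped = {}, {}
    for name, value in checkpoint.items():
        if name not in model_state:
            skipped[name] = "not in model"
        elif tuple(value.shape) != tuple(model_state[name].shape):
            skipped[name] = "shape {} vs {}".format(
                tuple(value.shape), tuple(model_state[name].shape))
        else:
            compatible[name] = value
    return compatible, skipped
```

With the variables from the script above, this would be called as `compatible, skipped = split_by_shape(weights_dict, model.state_dict())`, followed by `model.load_state_dict(compatible, strict=False)`. Printing `skipped` shows exactly which parameters were left out and why.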
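Note:
Later, however, I commented out that line of code and ran the script again, and to my surprise no error was raised. With the rpn parameters skipped, on the other hand, prediction accuracy was low and the results were often erratic.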
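The accuracy drop is consistent with how `load_state_dict` behaves: a parameter that receives no value from the checkpoint keeps whatever value it had when the model was constructed, so the skipped layers run with untrained weights. To see how much of the model is affected, one can count the unloaded parameters per submodule. This is a minimal sketch; `unloaded_by_module` and its `depth` argument are hypothetical names introduced here for illustration.

```python
from collections import Counter


def unloaded_by_module(model_keys, loaded_keys, depth=2):
    """Count model parameters that received no checkpoint value, per module prefix.

    `depth` is the number of leading name components used as the prefix.
    """
    loaded = set(loaded_keys)
    counts = Counter()
    for name in model_keys:
        if name not in loaded:
            prefix = ".".join(name.split(".")[:depth])
            counts[prefix] += 1
    return dict(counts)
```

A call such as `unloaded_by_module(model.state_dict().keys(), pretrained_dict.keys())` returns a dict mapping each module prefix to the number of its parameters that were not loaded. If the rpn entries appear there, those layers are predicting from their initial values.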