strict=False 但还是 size mismatch for xx.weight copying a param with shape [] from checkpoint错误的解决方法

目录

参考链接:

问题重现:

问题分析:

原因及解决方案


参考链接:

strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur_copying a param with shape torch.size([768]) from -CSDN博客

size mismatch for xx.weight错误的解决方法-CSDN博客

问题重现:

RuntimeError: Error(s) in loading state_dict for MMDataParallel:
        size mismatch for module.neck.def_convs.0.reppoints_pts_init_out.weight: copying a param with shape torch.Size([14, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([21, 64, 1, 1]).
        size mismatch for module.neck.def_convs.0.reppoints_pts_init_out.bias: copying a param with shape torch.Size([14]) from checkpoint, the shape in current model is torch.Size([21]).
        size mismatch for module.neck.def_convs.0.reppoints_pts_refine_out.weight: copying a param with shape torch.Size([14, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([21, 64, 1, 1]).
        size mismatch for module.neck.def_convs.0.reppoints_pts_refine_out.bias: copying a param with shape torch.Size([14]) from checkpoint, the shape in current model is torch.Size([21]).
        size mismatch for module.neck.def_convs.1.reppoints_pts_init_out.weight: copying a param with shape torch.Size([10, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 64, 1, 1]).
        size mismatch for module.neck.def_convs.1.reppoints_pts_init_out.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([15]).
        size mismatch for module.neck.def_convs.1.reppoints_pts_refine_out.weight: copying a param with shape torch.Size([10, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 64, 1, 1]).
        size mismatch for module.neck.def_convs.1.reppoints_pts_refine_out.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([15]).
        size mismatch for module.neck.def_convs.2.reppoints_pts_init_out.weight: copying a param with shape torch.Size([6, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([9, 64, 1, 1]).
        size mismatch for module.neck.def_convs.2.reppoints_pts_init_out.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([9]).
        size mismatch for module.neck.def_convs.2.reppoints_pts_refine_out.weight: copying a param with shape torch.Size([6, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([9, 64, 1, 1]).
        size mismatch for module.neck.def_convs.2.reppoints_pts_refine_out.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([9]).

问题分析:

在使用load_state_dict()时会报错,原因是你现有的模型和你权重文件的保存的模型结构不一样,可以通过

model.load_state_dict(state_dict, strict=False)

暂且忽略掉模型和参数文件中不匹配的参数,先将正常匹配的参数从文件中载入模型。

然而当笔者已经写明strict=False,重新运行代码还是会报同样的错误。这就很奇怪了,既然已经写明strict=False了,那运行过程中不匹配参数的不管就是了,为什么还要给我报错?

原因及解决方案

经过查阅资料之后,发现是这样的:

strict=False可以保证模型中的键与文件中的键不匹配时暂且跳过不管,但是一旦模型中的键和文件中的键匹配上了,PyTorch就会尝试帮我们加载参数,就必须要求参数的尺寸相同,所以会有上述报错。

比如在我们需要将某个预训练的模型的最后的全连接层的输出的类别数替换为我们自己的数据集的类别数,再进行微调,有时会遇到上述情况。这时,我们知道全连接层的参数形状会是不匹配,比如我们加载 ImageNet 1K 1000分类的预训练模型,它的最后一层全连接的输出维度是1000,但如果我们自己的数据集是10分类,我们需要将最后一层全链接的输出维度改为10。但是由于键名相同,所以PyTorch还是尝试给我们加载,这时1000和10维度不匹配,就会导致报错。

解决方案就是我们将 .pth 模型文件读入后,将其中我们不需要的层(通常是最后的全连接层)的参数pop掉即可

以 笔者自己的任务为例子,假设我们有一个 车道线检测 模型,并有一个参数文件 lanedetection.pth,它里面存储着 车道线检测模型在 CULane数据集上训练的参数,而我们要在CULane数据集上微调这个改进后的模型(添加了卷积模块)。

将最后 pth 文件加载进来之后(即 pretrained_model) 中报错层的参数直接pop掉,至于需要pop掉哪些键名,就是上面报错信息中提到了的,在这里就是 weight 和 bias

pretrained_model = torch.load(model_dir)
        
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_init_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_init_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_refine_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_refine_out.bias")
    
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_init_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_init_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_refine_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_refine_out.bias")
    
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_init_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_init_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_refine_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_refine_out.bias")
    
    
    
    # 使用神经网络模型 net 的 load_state_dict 函数来加载预训练模型的权重。strict=False 参数表示允许加载预训练模型中的权重,即使它们不完全匹配当前神经网络模型的结构。因此你可以将预训练模型的权重加载到与其不完全匹配的模型中
    net.load_state_dict(pretrained_model['net'], strict=False)

至此,模型就可以正常运行了。

即使缺失了weight 和 bias 这两个参数,这也是正常的,因为我们要对模型进行修改微调,本就不需要这两个参数,并且已经将它们从模型文件字典中pop掉了。现在,模型其他层的参数已经正常加载了,接下来可以微调自己的模型。

反正我们也不需要这些参数,就直接把这个键值对从字典中pop掉,以免 PyTorch 在帮我们加载时会出现加载这些维度不匹配的情况。

  • 31
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值