MPNCOV分类模型问题汇总


前言

记录深度学习模型测试过程中的常见问题


一、遇到的问题

1.载入作者的预训练模型时Pytorch遇到权重不匹配的问题

具体报错如下:

RuntimeError: Error(s) in loading state_dict for mpncovresnet50:
	size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, the shape in current model is torch.Size([2, 2048]).
	size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).

报错原因:因为下载作者的预训练模型中的全连接层是1000类别的,而我要训练的类别只有2类,所以会报不匹配的错误。
解决方案:从报错信息可以看出,是fc层的权重参数不匹配,只要不load 这一层的参数就可以了。
原始代码:在src/network/mpncovresnet.py文件里

def mpncovresnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = MPNCOVResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        print('load mpncovresnet50')
        checkpoint=torch.load('/data/weiqiang_peng/fast-MPN-COV/finetune/mpncovresnet50.pth')
        model.load_state_dict(checkpoint)
    return model

修改后代码:

def mpncovresnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = MPNCOVResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        print('load mpncovresnet50')
        # checkpoint=torch.load('/data/weiqiang_peng/fast-MPN-COV/finetune/mpncovresnet50.pth')
        # model.load_state_dict(checkpoint)
        # 加载pth预训练模型文件
        pretrained_dict = torch.load('/data/weiqiang_peng/fast-MPN-COV/finetune/mpncovresnet50.pth')
        model_dict = model.state_dict()
        # 重新制作预训练的权重,主要是减去参数不匹配的层,我这边报错的层名为“fc”
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)}
        # 更新权重
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    return model
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值