问题描述
当用Pytorch训练好的ResNet模型权重进行推理时出现了如下报错:
RuntimeError: Error(s) in loading state_dict for ResNet:
size mismatch for fc.weight: copying a param with shape torch.Size([2, 512]) from checkpoint, the shape in current model is torch.Size([100, 512]).
size mismatch for fc.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([100]).
原因分析:
出现以上报错的原因主要是推理所用的权重的某一层参数和加载的模型需要的参数匹配不上,如报错显示,这里是fc.weight匹配不上,也就是模型的全连接层,本文使用的是ResNet,可以直接models目录下查看 resnet.py
模型定义。可以看到ResNet类中代码是最原始cifar100任务中定义的100类,num_classes=100
,而报错显示所用训练好的模型权重中全连接层定义的维度是 2
。
解决方案:
将models目录下报错的模型 resnet.py
中的类别改为跟权重模型一致的2, num_classes=2
如下图所示:
诸如此类的问题都有可能是加载的模型权重(训练时),跟推理时出现了变化,因此改代码时要做好记录。