strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur
问题
我们知道通过
model.load_state_dict(state_dict, strict=False)
可以暂且忽略掉模型和参数文件中不匹配的参数,先将正常匹配的参数从文件中载入模型。
笔者在使用时遇到了这样一个报错:
RuntimeError: Error(s) in loading state_dict for ViT_Aes:
size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).
size mismatch for mlp_head.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).
一开始笔者很奇怪,我已经写明strict=False
了,不匹配参数的不管就是了,为什么还要给我报错。
原因及解决方案
经过笔者仔细打印模型的键和文件中的键进行比对,发现是这样的:strict=False
可以保证模型中的键与文件中的键不匹配时暂且跳过不管,但是一旦模型中的键和文件中的键匹配上了,PyTorch就会尝试帮我们加载参数,就必须要求参数的尺寸相同,所以会有上述报错。
比如在我们需要将某个预训练的模型的最后的全连接层的输出的类别数替换为我们自己的数据集的类别数,再进行微调,有时会遇到上述情况。这时,我们知道全连接层的参数形状会是不匹配,比如我们加载 ImageNet 1K 1000分类的预训练模型,它的最后一层全连接的输出维度是1000,但如果我们自己的数据集是10分类,我们需要将最后一层全链接的输出维度改为10。但是由于键名相同,所以PyTorch还是尝试给我们加载,这时1000和10维度不匹配,就会导致报错。
解决方案就是我们将 .pth 模型文件读入后,将其中我们不需要的层(通常是最后的全连接层)的参数pop掉即可。
以 ViT 为例子,假设我们有一个 ViT 模型,并有一个参数文件 vit-in1k.pth
,它里面存储着 ViT 模型在 ImageNet-1K 1000分类数据集上训练的参数,而我们要在自己的10分类数据集上微调这个模型。
model = ViT(num_classes=10)
ckpt = torch.load('vit-in1k.pth', map_location='cpu')
msg = model.load_state_dict(ckpt, strict=False)
print(msg)
直接这样加载会出错,就是上面的错误:
size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).
size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).
我们将最后 pth 文件加载进来之后(即 ckpt
) 中全连接层的参数直接pop掉,至于需要pop掉哪些键名,就是上面报错信息中提到了的,在这里就是 head.weight
和 head.bias
ckpt.pop('head.weight')
ckpt.pop('head.bias')
之后在运行,会发现我们打印的 msg
显示:
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
即缺失了head.weight
和 head.bias
这两个参数,这是正常的,因为在自己的数据集上微调时,我们本就不需要这两个参数,并且已经将它们从模型文件字典 ckpt
中pop掉了。现在,模型全连接之前的层(通常即所谓的特征提取层)的参数已经正常加载了,接下来可以在自己的数据集上进行微调。
因为反正我们也不用这些参数,就直接把这个键值对从字典中pop掉,以免 PyTorch 在帮我们加载时试图加载这些维度不匹配,我们也不需要的参数。