需求描述
在large-scale数据集上进行的预训练模型上进行具体任务的炼丹已经非常常见了,只不过我们可能的最后的任务不同。例如在ImageNet上的预训练模型最后一层是维度为1000的全连接层,但如果我们使用MiniImageNet,则类别只有100.因为导入预训练模型时,全连接层无法进行导入。直接使用model.load_state_dict()函数会出现如下问题
RuntimeError: Error(s) in loading state_dict for VisionTransformerDiffPruning:
size mismatch for head.weight: copying a param with shape torch.Size([1000, 384]) from checkpoint, the shape in current model is torch.Size([100, 384]).
size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([100]).
解决方案
方案则是:model.load_state_dict()
提供了加载部分预训练模型的选择
#加载model,mymodel是自己定义好的模型
pretrainedmodel = models.resnet50(pretrained=True)
mymodel =Net(...)
#读取参数
pretrained_dict =pretrainedmodel.state_dict()
model_dict = mymodel.state_dict()
#将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)