完整报错:error :size mismatch for fution.weight: copying a param with shape torch.Size([28, 56, 1, 1]) from checkpoint, the shape in current model is torch.Size([28, 28, 1, 1])
这是由于在加载模型的预训练权重时,遇到了一个size不匹配的问题,预训练模型和当前模型之间存在结构差异。
def load_matching_weights(model, pretrained_weights):
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_weights.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
# 使用此函数加载匹配的权重
load_matching_weights(model, checkpoint)
这个子函数也可以去掉!