在我们复现神经网络相关论文,或者自己做实验时通常会加入新模块或自己修改网络结构,这个时候如果加载网络原先的预训练权重会导致和自己模型的参数不匹配,但我们在训练自己的模型时希望能保留预训练权重,因此可以保留一部分匹配的模型权重并删除不匹配的,参考步骤如下。
1.首先检查原网络的预训练权重
pretrained_weight_path = "./convnext_tiny_1k_224_ema.pth"
state_dic = torch.load(pretrained_weight_path)
pretrained_keys = state_dic.keys()
for key in pretrained_keys:
print(key)
-
注意:网络的预训练权重可能只会显示一个model,这可能是原先训练的时候使用了多GPU等原因(具体目前不太了解),如果输出的是model,这在打印的时候在加个访问字典中的key:modle
state_dic = torch.load(pretrained_weight_path)["model"]
2.加载网络的预训练权重
net = create_model(num_classes=6)
#这里creat_model是自定义的一个创建网络的函数,这个网络是你新加入其他模块后的网络
net_keys = net.state_dict().keys()
#这里把每一个预训练权重的key存起来
3.查找缺失的权重中的key
#现在的模型权重比预训练权重多
missing_keys = net_keys - pretrained_keys
for key in missing_keys:
print(key)
#预训练权重比现在的模型权重多
unexpected_keys = pretrained_keys - net_keys
print("\n模型中缺失的权重")
for key in unexpected_keys:
print(key)
4.判断并删除不匹配的模块
pretrained_weight_path = "./convnext_tiny_1k_224_ema.pth"
#加载预训练权重,返回类型是字典
pretrained_dict = torch.load(pretrained_weight_path)
net = create_model(num_classes=6)
#加载自定义网络模型权重
model_dict = net.state_dict()
#判断预训练权重模型和自定义网络的模型参数,如果key和对应shape都相同则取出,否则就去掉
pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict and(v.shape == model_dict[k].shape)}
#更新修改后的参数
model_dict.update(pretrained_dict)
#并重新让模型加载参数dict
net.load_state_dict(model_dict,strict=True)
- 也可以通过将
load_state_dict(pretrained_dict,strict=false)
中strict设置为false来忽略掉不匹配的模块,但这并不是绝对去除,只是忽略,个人理解不太推荐这样。strict=True
:这表示必须完全匹配模型的所有参数。如果model_dict
中缺少某些参数,或者存在额外的参数,将会抛出错误。这通常用于确保加载的权重与模型结构完全一致。- 优点:
- 精确控制:可以确保只加载你认为重要的权重,避免潜在问题。
- 透明性高:清楚哪些权重被加载,哪些被忽略,有助于调试。
- 优点:
strict=False
:允许部分匹配,即可以加载部分参数,忽略多余的或缺失的参数。这在模型结构略有变动时非常有用。- 缺点:
- 不明确:可能导致某些参数未被加载,潜在地影响模型性能。
- 调试困难:如果有重要参数未加载,可能不易发现问题。
- 缺点: