客官别急,三步教你从容的加载预训练网络参数
1.拿到网络后先查看网络模型和预训练模型,2.将网络模型和预训练模型的键调整成一样,并加载, 3.将两者的参数都打印一下,看是否加载成功。(对应以下三点)
1. 查看网络参数
pretrained_dict1 = torch.load(model_path1, map_location='cpu')['state_dict']#预训练文件后缀是.tar
pretrained_dict2 = torch.load(model_path2)#预训练文件后缀是.pth
#1.查看预训练网络参数
for key ,value in pretrained_dict1.items():#pretrained_dict1,pretrained_dict2就是上面的东西
count+=1
print(key)
print(count)
#2.查看model的网络参数
for key ,value in model.state_dict.items():
print(key,value)
2. 加载模型遇到的两大问题
1. 模型的键不匹配
以下两代码,解决了键不匹配问题,一个是删除键的某一部分,一是添加键的某一部分。
例:
下面的错误是因为模型的model.state_dict().items()的键是conv1.weight,预训练的键是module.conv1.weight,导致不匹配。所以下面的代码是让module. 去掉
1.删除键的头部
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict2.items()}
当然有时候自己model的键需要改进,如下
2.补齐键的头部
checkpoint={'module.'+k:v for k,v in pretrained_dict.items()}
2. 预训练模型和自己的model长度不一样
# 删除pretrained_dict.items()中model所没有的东西
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 只保留预训练模型中,自己建的model有的参数
model_dict.update(pretrained_dict) # 将预训练的值,更新到自己模型的dict中
model.load_state_dict(model_dict) # model加载dict中的数据,更新网络的初始值
3. 通过查看加载参数,看是否加载成功
for value1 ,value2 in zip(checkpoint.items(), model.state_dict().items()):
print(value1,value2)
如下所示,model的参数和预训练的参数是一样的
4. 案例
(这里处理的只是针对本人的model加载的情况,要想正确加载,还需遵守上面3步)
def load_param(self, model_path):#这里的self就是model
model_dict = self.state_dict()
pretrained_dict = torch.load(model_path)#这里model_path的后缀是.pth可直接读取
# pretrained_dict = {k.replace('module.', ''): v for k, v in
# pretrained_dict.items()} # 因为pretrained_dict得到module.conv1.weight,但是自己建的model无module,只是conv1.weight,所以改写下
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 只保留预训练模型中,自己建的model有的参数
model_dict.update(pretrained_dict) # 将预训练的值,更新到自己模型的dict中
self.load_state_dict(model_dict) # model加载dict中的数据,更新网络的初始值