pytorch模型加载
基本的情形:
一、单卡训练好的模型文件,推理阶段的加载(即加载模型文件和model定义的是一致的)
state_dict = torch.load('checkpoint.pth.tar')
Mymodel.load_state_dict(state_dict)
二、多gpu训练保存的模型,在单卡情况下加载
当出现以第一种的写法会有问题时,自然的就需要检查Mymodel和要加载的模型文件的差异
# 打印你自己定义的模型的键值对
params = Mymodel.state_dict() # 获得模型的原始状态以及参数, orderdict数据类型
for k, v in params.items():
print(k) # 只打印key值,不打印具体参数值
# 打印加载的checkpoint文件的键值对
state_dict = torch.load('checkpoint.pth.tar')
for k, v in checkpoint.items():
print(k)
然后你就会发现有些key值多了module
,这时候我们只需要去掉多余的名字使得两者的key值相同就可以正常加载了。主要有3种:
-
从key中的第7个字符开始取
from collections import OrderedDict state_dict = torch.load('checkpoint.pth.tar') new_state_dict = OrderedDict() for k,v in state_dict.items(): name = k[7:] # remove `module` new_state_dict[name] = v # load params Mymodel.load_state_dict(new_state_dict)
-
module 替换为空字符
name = k.replace('module.', '')
-
最简单的方法,加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。
如果有多个GPU,将模型并行化,用DataParallel来操作。这个过程会将key值加一个module.
Mymodel = resnet18()# 实例化自己的模型; state_dict = torch.load('checkpoint.pt', map_location='cpu') if torch.cuda.device_count() > 1: model = nn.DataParallel(model) model.load_state_dict(checkpoint) # 可以直接将模型参数load进模型。
三、加载模型文件与当前构建模型相同部分层的权重
主要有两种写法:
-
在写模型类时,在训练好的模型基础上搭建自己的模型。即先带预训练模型建model,然后再修改、删除或者添加模块;
featureExtract = resnet18(pretrained=True) # load weight self.featureEncoder = nn.Sequential(*list(featureExtract.children())[:-2])
-
整个模型类已经实现,调用后加载部分层参数
使用前最好打印检查下,因为设置了
strict=false
,如果只差了module
是不会报错的。#第一种方法: mymodelB = TheModelBClass(*args, **kwargs) # strict=False,设置为false,只保留键值相同的参数 mymodelB.load_state_dict('checkpoint.pt', strict=False) #第二种方法: # 加载模型 model_pretrained = torch.load('checkpoint.pt') # mymodel's state_dict, # 如: conv1.weight # conv1.bias mymodelB_dict = mymodelB.state_dict() # 将model_pretrained的建与自定义模型的建进行比较,剔除不同的 pretrained_dict = {k: v for k, v in model_pretrained.items() if k in mymodelB_dict} # 更新现有的model_dict mymodelB_dict.update(pretrained_dict) # 加载我们真正需要的state_dict mymodelB.load_state_dict(mymodelB_dict)