#保存和加载整个模型
torch.save(model_object, ‘model.pth’)
model = torch.load(‘model.pth’)
#仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), ‘params.pth’)
model_object.load_state_dict(torch.load(‘params.pth’))
加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features…,Expected”.features直接原因是key值名字不对应
表明了加载过程中,期望获得的key值为feature…,而不是module.features…。这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的
解决上面的问题的三个办法:
方法(1) 对load 的模型创建新的字典,去掉不需要的key值"modules"
#original saved file with DataParallel
state_dict = torch.load(‘checkpoint.pt’) # 模型可以保存为pth文件,也可以为pt文件。
#create new OrderedDict that does not contain module.
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove module.
,表面从第7个key值字符取到最后一个字符,正好去掉了module.
new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
#load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。
方法(2) 直接用空白代替’module’
model.load_state_dict({k.replace(‘module.’,’’):v for k,v in torch.load(‘checkpoint.pth’).items()})
#相当于用’‘代替’module.’
#直接使得需要的键名等于期望的键名。
方法(3)最简单的办法 加载模型之后接着将模型DataParallel, 此时就可以load_state_stict
model = VGG()# 实例化自己的模型;
checkpoint = torch.load(‘checkpoint.pth’, map_location=‘cpu’) # 加载模型文件,pt, pth 文件都可以;
if torch.cuda.device_count() > 1:
# 如果有多个GPU,将模型并行化,用DataParallel来操作。这个过程会将key值加一个"module. ***"。
model = nn.DataParallel(model)
model.load_state_dict(checkpoint) # 接着就可以将模型参数load进模型。