torch的模型保存和加载 各种细节各种坑 尤其是多GPU训练会出现各种问题

#保存和加载整个模型
torch.save(model_object, ‘model.pth’)
model = torch.load(‘model.pth’)

#仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), ‘params.pth’)
model_object.load_state_dict(torch.load(‘params.pth’))

加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features…,Expected”.features直接原因是key值名字不对应

表明了加载过程中,期望获得的key值为feature…,而不是module.features…。这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的

解决上面的问题的三个办法:
方法(1) 对load 的模型创建新的字典,去掉不需要的key值"modules"
#original saved file with DataParallel
state_dict = torch.load(‘checkpoint.pt’) # 模型可以保存为pth文件,也可以为pt文件。
#create new OrderedDict that does not contain module.
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove module.,表面从第7个key值字符取到最后一个字符,正好去掉了module.
new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
#load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

方法(2) 直接用空白代替’module’
model.load_state_dict({k.replace(‘module.’,’’):v for k,v in torch.load(‘checkpoint.pth’).items()})
#相当于用’‘代替’module.’
#直接使得需要的键名等于期望的键名。

方法(3)最简单的办法 加载模型之后接着将模型DataParallel, 此时就可以load_state_stict
model = VGG()# 实例化自己的模型;
checkpoint = torch.load(‘checkpoint.pth’, map_location=‘cpu’) # 加载模型文件,pt, pth 文件都可以;
if torch.cuda.device_count() > 1:
# 如果有多个GPU,将模型并行化,用DataParallel来操作。这个过程会将key值加一个"module. ***"。
model = nn.DataParallel(model)
model.load_state_dict(checkpoint) # 接着就可以将模型参数load进模型。

  • 4
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值