Pytorch 保存加载模型时的坑
Pytorch 保存加载模型时的坑
在说Pytorch保存加载模型时的坑之前,先介绍一下pytorch对训练好的模型如何进行保存和加载。
方法1:保存模型的参数和结构信息
保存:
model=MobileNetV2(n_class=2)#加载模型
############进行训练##########
model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])#用多gpus 训练×××关键
############进行训练##########
torch.save(model, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth.tar"))#保存模型
恢复:
model=torch.load(args.load_path)#
这种方法会出现一个问题:当利用pytorch 1.0.0 保存好了模型后,加载时利用pytorch1.1.0 进行load() 时回报错,所以官方推荐使用第二种方法进行加载
方法二:官方推荐的方法,只保存和恢复模型中的参数
一个完整的例子:
迁移学习加载模型(此时 checkpoint 字典只有 state_dict ):
model=MobileNetV2(n_class=2)#加载模型结构
model_dict = model.state_dict()#获取模型参数(未加载保存的模型参数 )
if args.resume:#模型路径
if os.path.isfile(args.resume):
print(("=> loading checkpoint '{}'".format(args.resume)))
checkpoint = torch.load(args.resume)#获取模型参数
#因为我修改网络模型进行迁移学习,这一步是在checkpoint里获取没有修改的模型参数state_dict
state_dict = {k: v for k, v in checkpoint.items() if k in model_dict.keys()}
model_dict.update(state_dict)#更新已经保存的参数至model_dict
model.load_state_dict(model_dict)#加载模型参数
else:
print(("=> no checkpoint found at '{}'".format(args.resume)))
保存:–这里有坑
torch.save({"epoch":epoch, #一共训练的epoch
"model_state_dict":model.module.state_dict(), #保存模型参数×××××这里埋个坑××××
'epoch_acc': epoch_acc, #一共训练的epoch
"optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))
再加载:
print("start loading cls model")
model=MobileNetV2(n_class=2)
if os.path.isfile(args.load_path):
state_dict=torch.load(args.load_path)
print(state_dict['epoch'])#获取保存的参数 对应key值的参数
print(state_dict['epoch_acc'])
params=state_dict["model_state_dict"]
for param_tensor in params:#打印参数信息
print(param_tensor,"\t",params[param_tensor].size())
model.load_state_dict(params)
print("load cls model successfully")
填坑
这段保存模型参数的代码
torch.save({"epoch":epoch, #一共训练的epoch
"model_state_dict":model.module.state_dict(), #保存模型参数
'epoch_acc': epoch_acc, #一共训练的epoch
"optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))
torch.save({"epoch":epoch, #一共训练的epoch
"model_state_dict":model.state_dict(), #保存模型参数
'epoch_acc': epoch_acc, #一共训练的epoch
"optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))
与这段的不同在于model.module.state_dict()与model.state_dict()的区别
现在来打印一下
model=MobileNetV2(n_class=2)#加载模型结构
model_dict = model.state_dict()#获取模型参数(未加载保存的模型参数 )
model_dict----------model.module.state_dict()---------model.state_dict()三者参数的对应的名称(这里只打印几个)
model_dict:
features.0.0.weight torch.Size([32, 3, 3, 3])
features.0.1.weight torch.Size([32])
features.0.1.bias torch.Size([32])
features.0.1.running_mean torch.Size([32])
features.0.1.running_var torch.Size([32])
features.0.1.num_batches_tracked torch.Size([])
model.module.state_dict():
features.0.0.weight torch.Size([32, 3, 3, 3])
features.0.1.weight torch.Size([32])
features.0.1.bias torch.Size([32])
features.0.1.running_mean torch.Size([32])
features.0.1.running_var torch.Size([32])
features.0.1.num_batches_tracked torch.Size([])
model.state_dict():
module.features.0.0.weight torch.Size([32, 3, 3, 3])
module.features.0.1.weight torch.Size([32])
module.features.0.1.bias torch.Size([32])
module.features.0.1.running_mean torch.Size([32])
module.features.0.1.running_var torch.Size([32])
module.features.0.1.num_batches_tracked torch.Size([])
用多gpus进行训练后直接用model.state_dict()进行保存的模型,每个层参数的名称前面会加上module,这时候再用单卡 gpu model_dict加载model.state_dict()参数时会出现名称不匹配的情况。
因此保存模型时注意使用model.module.state_dict():
总结
1.多gpus训练 用model.state_dict() 保存前面会加上网络参数名称前会加上 module
2.单gpus加载模型,需要去掉网络参数名称前加上的module
两种方法:
(1) 用model.module.state_dict()保存
(2) 去掉网络参数名称前会加上的module再加载模型
3.推荐多gpus训练使用model.module.state_dict()保存,然后单gpu加载,
此时如果还需要多gpu训练可以在加载模型参数后使用torch.nn.DataParallel进行训练
还有另外的思路可参考 @[参考这里][参考这里]多gpu训练(https://blog.csdn.net/CV_YOU/article/details/86670188)(https://blog.csdn.net/qq_32998593/article/details/89343507)