深度模型参数save & load 遇到的各种问题

1. torch版本问题

1.1 torch1.5版本及以下

import torch
## 模型参数保存
torch.save(model.state_dict(),'checkpoint/xxx.pth')

## 模型参数加载
args = torch.load('checkpoint/xxx.pth')    # 参数读取
model_state_dict = model.state_dict()
for key in args:
	if key in model_state_dict:
		model_state_dict[key] = args[key]   
model.load_state_dict(model_state_dict)    # 参数加载

1.2 torch1.6版本

注意: 1.6之后pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。
因此保存参数时,应该将其转换为非zip格式,即_use_new_zipfile_serialization=False

import torch
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)

2. 加载参数,但模型准确率变了

原因: 网络结构中有BN层或Dropout层等
训练阶段,BN层获取的批次数据属性(均值、方差)会被记录下来,用于对测试数据的标准化;
Dropout层,在训练的阶段会有一些神经元权重被置零,但是在测试阶段,这些神经元又被重新使用
解决方法: 在保存模型之前,需要把模型进行eval,固定当下的模型参数,用于接下来的模型预测。

import torch
model.eval()   # 很重要
checkpoint=model.state_dict()
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)

3. 加载参数时key值多个module

3.1 问题

加载之前保存的模型参数,期望获得的key值为 feature…,但得到的key值为 module.features…

3.2 原因

由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。

3.2 解决方案

  • 去掉不需要的key值"module".
import torch
args = torch.load('checkpoint/xxx.pth')    # 参数读取
model_state_dict = model.state_dict()
for key in args:
	if key[7:] in model_state_dict:
		model_state_dict[key[7:]] = args[key]
model.load_state_dict(model_state_dict)
  • 加载模型之后,接着将模型DataParallel
import torch
args = torch.load('checkpoint/xxx.pth')    # 参数读取
model = nn.DataParallel(model)  #模型并行化,这个过程会将key值加一个module
model_state_dict = model.state_dict()
for key in args:
	if key in model_state_dict:
		model_state_dict[key] = args[key]
model.load_state_dict(model_state_dict)
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值