1. torch版本问题
1.1 torch1.5版本及以下
import torch
## 模型参数保存
torch.save(model.state_dict(),'checkpoint/xxx.pth')
## 模型参数加载
args = torch.load('checkpoint/xxx.pth') # 参数读取
model_state_dict = model.state_dict()
for key in args:
if key in model_state_dict:
model_state_dict[key] = args[key]
model.load_state_dict(model_state_dict) # 参数加载
1.2 torch1.6版本
注意: 1.6之后pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。
因此保存参数时,应该将其转换为非zip格式,即_use_new_zipfile_serialization=False
import torch
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)
2. 加载参数,但模型准确率变了
原因: 网络结构中有BN层或Dropout层等
训练阶段,BN层获取的批次数据属性(均值、方差)会被记录下来,用于对测试数据的标准化;
Dropout层,在训练的阶段会有一些神经元权重被置零,但是在测试阶段,这些神经元又被重新使用
解决方法: 在保存模型之前,需要把模型进行eval,固定当下的模型参数,用于接下来的模型预测。
import torch
model.eval() # 很重要
checkpoint=model.state_dict()
torch.save(state_dict, "xxx.pth", _use_new_zipfile_serialization=False)
3. 加载参数时key值多个module
3.1 问题
加载之前保存的模型参数,期望获得的key值为 feature…,但得到的key值为 module.features…
3.2 原因
由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。
3.2 解决方案
- 去掉不需要的key值"module".
import torch
args = torch.load('checkpoint/xxx.pth') # 参数读取
model_state_dict = model.state_dict()
for key in args:
if key[7:] in model_state_dict:
model_state_dict[key[7:]] = args[key]
model.load_state_dict(model_state_dict)
- 加载模型之后,接着将模型DataParallel
import torch
args = torch.load('checkpoint/xxx.pth') # 参数读取
model = nn.DataParallel(model) #模型并行化,这个过程会将key值加一个module
model_state_dict = model.state_dict()
for key in args:
if key in model_state_dict:
model_state_dict[key] = args[key]
model.load_state_dict(model_state_dict)