文章目录
模型的保存与加载
1、序列化与反序列化
- 序列化:变量从内存中变成可存储或传输的过程称之为序列化
- 反序列化:把变量内容从序列化的对象重新读到内存里称之为反序列化
(1)PyTorch
的序列化——torch.save
主要参数:
obj
:对象f
:输出路径
(2)PyTorch
的反序列化——torch.load
f
:文件路径map_location
:指定存放位置(CPU或GPU)
2、模型保存与加载的两种方式
模型的保存方式有两种,一种是把整个模型的所有东西都保存下来,这种方法占用的磁盘空间较大,但是保存的信息较为全面;另一种是只保存模型的关键参数,其他的不保存,这种方法占用的磁盘空间较小,但是只保存了模型的关键参数信息,官方推荐第二种保存方式
- 方法一:保存整个模型——
torch.save(net, path)
- 方法二:保存模型参数——
state_dict = net.state_dict();torch.save(state_dict , path)
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整个模型
torch.save(net, path_model)
# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
加载模型
- 加载整个模型
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
# LeNet2(
# (features): Sequential(
# (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
# (1): ReLU()
# (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
# (4): ReLU()
# (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
# (classifier): Sequential(
# (0): Linear(in_features=400, out_features=120, bias=True)
# (1): ReLU()
# (2): Linear(in_features=120, out_features=84, bias=True)
# (3): ReLU()
# (4): Linear(in_features=84, out_features=2019, bias=True)
# )
# )
- 通过加载模型参数加载模型
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load.keys())
# odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias'])
3、模型断点续训练
模型断点续训练主要用到模型训练意外中断的情况,模型断点续训练能够保证中断后的训练能够从中断点开始继续进行训练,因此需要对训练的状态进行保存,主要保存三方面的内容:
- 模型中的参数
- 优化器中的参数
- Epoch
checkpoint = {
"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch
}