一、序列化与反序列化
训练好的模型存储在内存中,内存中的数据不具备长久性存储的功能,所以,我们需要将模型从内存搬到硬盘中进行长久的存储,以备将来使用。这就是模型的保存与加载,也即序列化与反序列化。序列化指的是将内存中的对象保存到硬盘当中,以二进制序列的形式存储下来。反序列化指的是将硬盘当中的二进制序列反序列化的存储到内存中,得到对象,这样,我们就可以在内存中使用这个模型。序列化与反序列的目的是将数据、模型长久的保存。
二、PyTorch中的模型保存与加载方式
模型的保存与加载分别有两种模式:保存整个模型与保存模型参数。
1.模型保存
- torch.save
模型保存(序列化)
主要参数:- obj:对象(模型、张量、parameters、dict等)
- f:输出路径(硬盘当中的路径,用于保存)
模式一:保存整个Module
torch.save(net,path)
模式二:保存模型参数(官方推荐)
state_dict=net.state_dict()
torch.save(state_dict,path)
下面我们通过代码来简单观察模型保存两种模式的不同
import torch
import numpy as np
import torch.nn as nn
class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
def initialize(self):
for p in self.parameters():
p.data.fill_(20191104)
net = LeNet2(classes=2019)
# "模拟训练"
# 查看第一个卷积层的第一个卷积核的参数
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整个模型
torch.save(net, path_model)
# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
在当前文件目录下,就会生成model.pkl
与model_state_dict.pkl
。
2. 模型加载
- torch.load
模型加载(反序列化)
主要参数:- f:文件路径
- map_location:指定存放位置,cpu or gpu
模式一:加载整个Module
torch.load(net,path)
模式二:加载模型参数(官方推荐)
state_dict_load = torch.load(path_state_dict)
# 初始化模型
net_new = LeNet2(classes=2019)
# 模型加载参数
net_new.load_state_dict(state_dict_load)
下面我们通过代码来简单观察模型加载两种模式的不同
import torch
import numpy as np
import torch.nn as nn
class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
def initialize(self):
for p in self.parameters():
p.data.fill_(20191104)
# ================================== load net ===========================
path_model = "./model.pkl"
net_load = torch.load(path_model)
# 打印模型结构
print(net_load)
# ================================== load state_dict ===========================
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load.keys())
# ================================== update state_dict ===========================
net_new = LeNet2(classes=2019)
print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])
LeNet2(
(features): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(</