# 模型保存/加载的四种方法
# 1. 保存/加载状态字典(state_dict)
# 2. 保存/加载整个模型(entire model)
# 3. 保存/加载checkpoint信息
# 4. 保存/加载多个模型到一个文件
# 注:详情请参阅 PyTorch 官方文档: https://pytorch.org/tutorials/beginner/saving_loading_models.html
"""模型保存与加载
方法1:保存/加载状态字典(state_dict)
该方法具有更大的灵活性,推荐使用
方法2:保存/加载整个模型(entire model)
方法3:保存/加载checkpoint
该方法以字典形式存储模型信息,推荐训练过程使用
方法4:保存/加载多个模型到一个文件
该方法可用于模型重用(即使用预训练模型)
主要针对像GAN/sequence-to-sequence model/an ensemble of models(一组模型) 这样包含多个torch.nn.Modules的模型
"""
import torch
import torch.nn as nn # 模型构建模块
import torch.nn.functional as F # functional ops (F.relu is used by Net.forward)
import torch.optim as optim # 优化器模块
# Define the model: a small CNN (two conv + max-pool stages, then three linear layers).
class Net(nn.Module):
    """Small convolutional classifier used by the save/load examples below.

    The layer sizes fix the expected input: 3 channels, and a spatial size that
    shrinks to 5x5 after two (5x5 conv, no padding) + (2x2 max-pool) stages,
    e.g. 32x32 inputs. Each sample produces 10 output values.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the forward pass: (N, 3, H, W) input -> (N, 10) output."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dimensions except the batch dimension. Unlike
        # ``x.view(-1, 16 * 5 * 5)`` this never silently re-splits the batch
        # when the spatial size is wrong; fc1 then raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Instantiate the network that every example below saves from.
net = Net()
# Optimizer: SGD with momentum over all of net's parameters.
optimizer = optim.SGD(params=net.parameters(), momentum=0.9, lr=0.001)
#------------------------------------------
"""方法1:保存/加载模型(state_dict)"""
# Method 1 (recommended): persist only the parameter/buffer mapping, not the class.
path = "./state_dict_model.pth"  # output file
# save: serialize net's state_dict
torch.save(obj=net.state_dict(), f=path)
# load: construct a fresh Net first, then copy the saved tensors into it
model = Net()
model.load_state_dict(state_dict=torch.load(f=path))
#-------------------------------------------
"""方法2:保存/加载整个模型"""
# Method 2: pickle the whole nn.Module object (class reference + weights).
# The resulting file can only be loaded where the `Net` class is importable
# under the same module path it had when the model was saved.
path = "./entire_model.pth"  # output file
# save
torch.save(net, path)
# load
# A pickled module is not a plain tensor container, so it must be loaded with
# weights_only=False; since PyTorch 2.6 the default is True, which rejects this
# file. This executes arbitrary pickle code: only load files you trust.
# (The `weights_only` argument requires PyTorch >= 1.13.)
model = torch.load(path, weights_only=False)
#------------------------------------------
"""方法3:保存/加载checkpoint"""
# Method 3: bundle everything needed to resume training (epoch counter, last
# loss, model weights, optimizer state) into a single dict and save that.
# save
epoch = 5
loss = 0.4
path = "./checkpoint_model.pth"  # checkpoint file
torch.save({"epoch": epoch,
            "model_state_dict": net.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
            "loss": loss},
           path)
# load
# Model and optimizer must be constructed before their states can be restored.
# The optimizer has to be built over the NEW model's parameters; building it
# over `net` would make every later optimizer.step() update `net`, not `model`.
model = Net()
# lr/momentum here are placeholders: optimizer.load_state_dict() below
# replaces them with the hyperparameters stored in the checkpoint.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Restore the saved checkpoint contents.
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optim_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
#------------------------------------------
"""方法4:保存/加载多个模型"""
# Method 4: store several models and their optimizers in one file (e.g. a GAN
# or an ensemble) by saving a dict with one entry per state_dict.
# Build the models and their optimizers.
netA = Net()
netB = Net()
optimizerA = optim.SGD(netA.parameters(), lr=0.01, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.95)
# save multiple models
path = "./model.pth"  # output file
torch.save({"modelA_state_dict": netA.state_dict(),
            "modelB_state_dict": netB.state_dict(),
            "optimA_state_dict": optimizerA.state_dict(),
            "optimB_state_dict": optimizerB.state_dict()},
           path)
# load multiple models
# Fresh models and optimizers must exist before their states can be restored.
netA = Net()
netB = Net()
optimizerA = optim.SGD(netA.parameters(), lr=0.01, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.95)
checkpoint = torch.load(path)
# Restore into the instances created just above (netA / netB).
netA.load_state_dict(checkpoint["modelA_state_dict"])
netB.load_state_dict(checkpoint["modelB_state_dict"])
optimizerA.load_state_dict(checkpoint["optimA_state_dict"])
optimizerB.load_state_dict(checkpoint["optimB_state_dict"])
# Switch to evaluation mode before inference ...
netA.eval()
netB.eval()
# ... or back to training mode to continue training.
netA.train()
netB.train()
#------------------------------------------
# Mode switching: put the model in evaluation mode before evaluation/inference,
# which changes how layers such as dropout and batch normalization behave, and
# switch back to training mode before resuming training. (Net itself contains
# no such layers, so here only the module's `training` flag changes.)
model.train(False)  # equivalent to model.eval()
# --or--
model.train(True)  # equivalent to model.train()