1.前言
训练好了一个模型, 我们当然想要保存它, 留到下次要用的时候直接提取直接用,下面我将来讲如何存储训练好的模型参数
2.torch.save(保存模型)
首先,先搭建一个神经网络
import torch
from torch import nn
import matplotlib.pyplot as plt
torch.manual_seed(11) # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # [100] -> [100,1]
y = x.pow(2) + 0.5*torch.rand(x.size()) # y的形状与x一样
def make_and_save_model():
network = torch.nn.Sequential(
torch.nn.Linear(1, 8),
torch.nn.ReLU(),
torch.nn.Linear(8, 1)
)
optimizer = torch.optim.SGD(network.parameters(), lr=0.3) #优化器
criterion = torch.nn.MSELoss() #损失函数
# 训练
for i in range(200):
prediction = network(x) #数据放入模型后得到预测值
loss = criterion(prediction, y) #计算预测值与真实值之间的误差
optimizer.zero_grad() #清空梯度
loss.backward() #误差反向传播
optimizer.step() #更新参数
torch.save(network, 'network.pth') # 保存整个网络
torch.save(network.state_dict(), 'network_params.pth') # 只保存网络中的参数
plt.figure(1, figsize = (10,3))
plt.subplot(131)
plt.title('network')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
plt.pause(1)
3.torch.load整个网络
这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.
def load_whole_model():
network_whole = torch.load('network.pth')
prediction = network_whole(x)
plt.figure(1, figsize = (10,3))
plt.subplot(132)
plt.title('network_whole')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
plt.pause(1)
4.torch.load网络参数(只提取参数)
这种方式将会提取所有的参数, 然后再放到你的新建网络中
def load_only_params():
network_params = torch.nn.Sequential(
torch.nn.Linear(1, 8),
torch.nn.ReLU(),
torch.nn.Linear(8, 1)
)
network_params.load_state_dict(torch.load('network_params.pth'))
prediction = network_params(x)
plt.figure(1, figsize = (10,3))
plt.subplot(133)
plt.title('network_params')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
5.调用三个函数
会看到加载后的模型画出的图是一样的,说明模型的参数正确加载了。
make_and_save_model()
load_whole_model()
load_only_params()