保存和读取神经网络的两种方法
直接保存整个训练结果
torch.save(net, 'net.pkl') # 保存整个网络
def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
保存训练结果中的参数
torch.save(net.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少)
def restore_params():
# 构建的神经网络结构必须一样
net3 = net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 提取参数节点
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
示例
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
# torch.manual_seed() 用于设计随机初始化种子的,要知道神经网络都需要初始化,那么如何初始化,以及怎么保证初始化每次都相同?这时候使用同样的随机初始化种子即可保证
torch.manual_seed(1) # reproducible
# 假数据
# squeeze的用法主要就是对数据的维度进行压缩或者解压。torch.squeeze(input, dim, out=None)
# input (Tensor) – 输入张量
# dim (int, optional) – 如果给定,则input只会在给定维度挤压
# out (Tensor, optional) – 输出张量
# torch.linspace(start, end, steps=100, out=None) → Tensor返回一个1维张量,包含在区间start和end上均匀间隔的step个点。输出张量的长度由steps决定。
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
# 均匀分布 torch.rand(*sizes, out=None) → Tensor返回一个张量,包含了从区间[0, 1)的均匀分布中抽取的一组随机数。张量的形状由参数sizes定义。
y = x.pow(2) + 0.2 * torch.rand(x.size())
def save():
# 建网络
net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
# 训练
for t in range(100):
prediction = net(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 两种途径来保存
torch.save(net, 'net.pkl') # 保存整个网络
torch.save(net.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少)
# 绘制结果
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
def restore_params():
# 构建的神经网络结构必须一样
net3 = net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 提取参数节点
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
save()
restore_net()
restore_params()
结语
如果你发现文章有什么问题,欢迎留言指正。
如果你觉得这篇文章还可以,别忘记点个赞加个关注再走哦。
如果你不嫌弃,还可以关注微信公众号———梦码城(持续更新中)。
梦码在这里感激不尽!!