import torch
import matplotlib.pyplot as plt
torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net1, 'net.pkl') # entire net
torch.save(net1.state_dict(), 'net_params.pkl') # parameters
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(1, 3, 1) # 一行三列三个图,绘制第一张图
plt.title('Net1')
plt.scatter(x.data, y.data)
plt.plot(x.data, prediction.data, 'r-', lw=5)
def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(1, 3, 2) # 一行三列三个图,绘制第二张图
plt.title('Net1')
plt.scatter(x.data, y.data)
plt.plot(x.data, prediction.data, 'r-', lw=5)
def restore_params(): # 推荐
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(1, 3, 3) # 一行三列三个图,绘制第三张图
plt.title('Net1')
plt.scatter(x.data, y.data)
plt.plot(x.data, prediction.data, 'r-', lw=5)
# save net1
save()
# restore entire net
restore_net()
# restore only the net parameters
restore_params()
plt.show()
首先我们来看一个语句torch.manual_seed(1)
,torch.manual_seed
在pytorch文档中只有一句话的解释,即为:设置用于生成随机数的种子。由于我们在这里需要用到两种方法,一种是保存和提取整个网络,一种是用来保存和提取参数,所以我们必须要保证这两种方法的随机默认参数是相同的,所以要进行这条语句的设定。
此处我们使用的样例是回归问题,我们建立了fake data(即自己定义的特征点)。首先在save函数中利用Sequential容器搭建一个神经网络,并且分别在当前目录下保存pkl文件,以便之后的提取使用。在名为restore_net和restore_params的函数中对两个文件中的内容进行了提取(此处要注意两个函数中的语句差异)。最后由下图输出结果可以看出,这三个图是相同的,说明提取成功。