'''
Author: 365JHWZGo
Description: 5.保存和获取神经网络【两种方法】
Date: 2021-10-23 12:26:12
FilePath: \pytorch\pytorch\day06-2.py
LastEditTime: 2021-10-23 17:36:48
LastEditors: 365JHWZGo
'''
#导包
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
# Create toy regression data: y = x^2 + uniform noise in [0, 0.2).
# x has shape (100, 1) — unsqueeze adds the feature dimension nn.Linear expects.
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# NOTE: Variable is deprecated since PyTorch 0.4 — tensors are used directly.
# Neither the input x nor the target y needs requires_grad=True; only the
# network parameters require gradients. Setting requires_grad=True on the
# target y was incorrect: backward() would needlessly build a graph through it.
def save():
    """Train a small regression net on (x, y), plot its fit, and save it two ways."""
    # Network: 1 -> 10 -> 1 with a ReLU in between.
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    # 100 steps of plain SGD on the mean-squared error.
    for _ in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Left subplot: the fit of the freshly trained net1.
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('net1', fontdict={'color': 'red', 'fontsize': 20})
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), lw=5, c='red')

    # Method 1: save the entire network object (architecture + weights).
    torch.save(net1, 'net1.pkl')
    # Method 2: save only the learned parameters (state dict).
    torch.save(net1.state_dict(), 'net1_state_dict.pkl')
# Method 1: restore the entire saved network
def restore_net():
    """Method 1: reload the whole pickled network and plot its prediction."""
    # NOTE(review): torch.load on a whole-model file goes through pickle —
    # only load files you trust, and the class definitions must be importable.
    net2 = torch.load('net1.pkl')
    prediction = net2(x)

    # Middle subplot: prediction of the restored net2.
    plt.subplot(132)
    plt.title('net2', fontdict={'color': 'red', 'fontsize': 20})
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), lw=5, c='red')
# Method 2: restore only the network's saved parameters
def restore_params():
    """Method 2: rebuild the architecture, then load only the saved parameters."""
    # The architecture must match the one built in save() exactly,
    # otherwise load_state_dict() will fail on mismatched keys/shapes.
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    net3.load_state_dict(torch.load('net1_state_dict.pkl'))
    prediction = net3(x)

    # Right subplot: prediction of net3 with the restored parameters.
    plt.subplot(133)
    plt.title('net3', fontdict={'color': 'red', 'fontsize': 20})
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), lw=5, c='red')
    # Show all three subplots once every restore has drawn its panel.
    plt.show()
# Guard the script entry point so importing this module does not
# trigger training, file I/O, and plotting as a side effect.
if __name__ == "__main__":
    save()
    restore_net()
    restore_params()
# Notes on whether Variable wrapping is needed here — all of these
# options behave the same for this demo:
#   1. write nothing (use the plain tensors)
#   2. x, y = Variable(x), Variable(y)
#   3. x, y = Variable(x, requires_grad=True), Variable(y, requires_grad=True)
#   4. x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)