目录
1. 保存和提取神经网络
(1)代码
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.manual_seed(1)  # reproducible

# Fake data: 100 evenly spaced points on [-1, 1] reshaped to a (100, 1)
# column, and a parabola y = x^2 plus uniform noise in [0, 0.2).
# Since PyTorch 0.4 plain tensors work directly with autograd, so the legacy
# torch.autograd.Variable wrapper is unnecessary. These inputs need no
# gradient, and a fresh tensor already has requires_grad=False.
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
def save():
    """Train a small regression net on the global (x, y), save it, and plot the fit.

    Two files are written to the current working directory:
      * 'net.pkl'        -- the entire module, pickled via torch.save(net1, ...)
      * 'net_params.pkl' -- only the module's state_dict (its parameters)

    The fitted curve is drawn into subplot 131 of figure 1.
    """
    # 1 -> 10 -> 1 fully connected network with a ReLU in between.
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),   # input layer
        torch.nn.ReLU(),          # activation function
        torch.nn.Linear(10, 1),   # output layer
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.2)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(net1, 'net.pkl')                      # entire net
    torch.save(net1.state_dict(), 'net_params.pkl')  # parameters only

    # Re-evaluate after training. The `prediction` left over from the loop was
    # computed *before* the final optimizer.step(), so it would not reflect the
    # parameters just saved (which restore_net()/restore_params() reload and
    # plot in subplots 132/133).
    with torch.no_grad():
        prediction = net1(x)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.numpy(), 'r-', lw=5)
def restore_net():
    """Reload the entire network pickled in 'net.pkl' and plot its fit in subplot 132."""
    # 'net.pkl' holds a whole pickled nn.Sequential, not just tensors.
    # Starting with PyTorch 2.6, torch.load defaults to weights_only=True, which
    # refuses to unpickle arbitrary modules, so opt out explicitly (the
    # `weights_only` argument exists since PyTorch 1.13). This is acceptable
    # only because the file was written locally by save(). Never load untrusted
    # files this way: unpickling can execute arbitrary code.
    net2 = torch.load('net.pkl', weights_only=False)

    # Inference only; no autograd graph is needed for plotting.
    with torch.no_grad():
        prediction = net2(x)

    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.numpy(), 'r-', lw=5)
def restore_params():
    """Rebuild the network, load the parameters saved in 'net_params.pkl', and plot.

    The layer stack must match the one built in save(), because the saved
    state_dict is loaded back onto these modules by key. The fit is drawn into
    subplot 133, and plt.show() then displays the whole figure.
    """
    layers = [
        torch.nn.Linear(1, 10),   # input layer
        torch.nn.ReLU(),          # activation function
        torch.nn.Linear(10, 1),   # output layer
    ]
    net3 = torch.nn.Sequential(*layers)
    saved_state = torch.load('net_params.pkl')
    net3.load_state_dict(saved_state)

    fitted = net3(x)
    xs = x.data.numpy()

    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(xs, y.data.numpy())
    plt.plot(xs, fitted.data.numpy(), 'r-', lw=5)
    plt.show()
# Demo sequence, run in this order:
#   1. save():           train net1 and save both the whole net and its parameters
#   2. restore_net():    reload the entire pickled net
#   3. restore_params(): rebuild the architecture and load only the parameters
for demo_step in (save, restore_net, restore_params):
    demo_step()
(2)运行结果
Process finished with exit code 0