神经网络的保存和提取
如果想要保存训练到当前状态的神经网络,为了第二天继续训练或者是提取,可以现将当前的状态保存下来
保存可以保存整个神经网络,也可以保存参数
提取,可以直接提取整个神经网络,也可以提取参数,构造一个一模一样的神经网络,直接把参数用于里面
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
# fake data
x = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x,y = Variable(x,requires_grad = False),Variable(y,requires_grad = False)
def save():
#快速搭建法搭建一个神经网络
net1 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
optimizer = torch.optim.SGD(net1.parameters(),lr = 0.5)
loss_func = torch.nn.MSELoss()
#训练一百步
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net1,'net.pkl') #储存整个神经网络
torch.save(net1.state_dict(),'net_param.pkl') #保存的是神经网络的parameters
#画图
plt.figure(1,figsize = (10,3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw = 5)
def restore_net():
net2 = torch.load('net.pkl')
#就可以直接提取整个神经网络包括参数了
prediction = net2(x)
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw = 5)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
#这种只提取参数的方法,需要重新构造一个核原来相同的神经网络
net3.load_state_dict(torch.load('net_param.pkl')) #听说这种方法快一些
prediction = net3(x)
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw = 5)
plt.show()
save()
restore_net()
restore_params()