一、步骤
1.创建数据
import torch
import torch.nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x**2+0.2*torch.rand(x.size())
x=Variable(x)
y=Variable(y)
2.搭建网络
def save():
#快速搭建网络
net=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
loss_func=torch.nn.MSELoss()
#训练网络
for i in range(100):
prediction=net(x)
loss=loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#可视化
plt.figure(1,figsize=(10,3))
plt.subplot(131)
plt.title('Ne1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),color='red',lw='5')
3.保存网络的两种方法:
#保存网络的两种方法
torch.save(net,'net.pkl') #保存整个网络
torch.save(net.state_dict(),'net_params.pkl') #只保存网络中的参数。速度款,占内存少
4.提取网络(两种不同的提取方法):
- 提取整个神经网络
#提取整个神经网络
def restore_net():
#提取整个神经网络,网络大的时候可能会提取的比较慢
net=torch.load('net.pkl')
prediction=net(x)
#可视化
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), color='red', lw='5')
- 提取神经网络的参数
#提取网络参数
def resotre_params():
#先新建一个网络
net=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
#将提取出来的已经训练好的参数赋给网络net
net.load_state_dict(torch.load('net_params.pkl'))
prediction=net(x)
#可视化
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), color='red', lw='5')
二、运行及结果显示
#两种方法保存网络
save()
#提取整个网络
restore_net()
#提取网络的参数,复制到新的网络
resotre_params()
plt.show()