Pytorch_rapid_Neural_net&save_net&restore_net

快速搭建神经网络
函数库

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

创建伪数据

n_data=torch.ones(100,2)
x0=torch.normal(2*n_data,1)
x1=torch.normal(-2*n_data,1)
y0=torch.zeros(100)
y1=torch.ones(100)
x=torch.cat((x0,x1),1).data.Type(FloatTensor)
y=torch.cat((y0,y1),).data.Type(LongTensor)
x,y=Variable(x),Variable(y)

创建神经网络

net=torch.nn.Sequential(	
	torch.nn.Linear(2,10)
	torch.nn.ReLU()
	torch.nn.Linear(10,2)
)
print(net)#看看生成的网络如何

输出的网络

Sequential(
  (0): Linear(in_features=2, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=2, bias=True)
)

储存神经网络
创建新的数据

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.manual_seed(2)#reproducible&&运行这段代码会发现,每次得到的随机数是固定的

#假数据
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x.pow(2)+0.2*torch.rand(x.size())
x,y=Variable(x),Variable(y)

搭建神经网络

net=torch.nn.Sequential(
	torch.nn.Linear(1,10)
	torch.nn.ReLU()
	torch.nn.Linear(10,1)
)

优化神经网络

optimizer=torch.optim.SGD(net.parameters(),lr=0.5)
loss_func=torch.nn.MSELoss()

训练神经网络

for t in range(100):
	prediction=net(x)
	loss=loss_func(prediction,y)
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()

储存神经网络

def save():
	#{把上面搭建并训练的网络掺进来}
	torch.save(net1,'net.pkl')
	torch.save(net1.state_dict(),'net_params.pkl')
	#画图
	plt.figure(1,figsize(10,3))
	plt.subplot(131)
	plt.title('Net1')
	plt.scatter(x.data.numpy(),y.data.numpy(),'r-',lw=5)

提取神经网络

def restore_net():
	net2=torch.load('Net.pkl')
	prediction=net2(x)
	plt.subplot(132)
	plt.title('Net2')
	plt.scatter(x.data.numpy(),y.data.numpy())
	plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

提取神经网络参数

def restore_params():
	net3=torch.nn.Sequential(
	torch.nn.Linear(1,10)
	torch.nn.ReLU()
	torch.nn.Linear(10,1)
)
	net3.load_state_dict(torch.load('net_params.pkl'))
	prediction=net3(x)
	plt.subplot(133)
	plt.title('Net3')
	plt.scatter(x.data.numpy(),y.data.numpy())
	plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
	plt.show()

主函数

save()
restore_net()
restore_params()

生成图片如图所示
在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值