Pytorch搭建简单神经网络(三)——快速搭建、保存与提取

在之前的两篇文章中分别介绍了如何用pytorch搭建简单神经网络用于回归与分类。但是如何快速搭建一个简单的神经网络而不是定义一个类再去调用,以及我们定义了一个网络并训练好,该如何在日后去调用这个网络去实现相应的功能。

1、其他的相关代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable

x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

x , y =(Variable(x),Variable(y))

2、之前用类定义的网络

class Net(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden1 = nn.Linear(n_input,n_hidden)
        self.hidden2 = nn.Linear(n_hidden,n_hidden)
        self.predict = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.relu(out)
        out =self.predict(out)

        return out

输出网络可以看出定义的网络如下

net = Net(1,20,1)
print(net)

在这里插入图片描述

一、快速搭建

接下来用快速搭建的方式来定义网络

需要使用到 torch.nn.Sequential(* args),一个时序容器。Modules会以他们传入的顺序被添加到容器中。当然,也可以传入一个OrderedDict[1]

 net1 = torch.nn.Sequential(
    nn.Linear(1,20),
    torch.nn.ReLU(),
    nn.Linear(20,20),
    torch.nn.ReLU(),
    nn.Linear(20,1)
)

输出net后可以看到

print(net1)

在这里插入图片描述
可以看出和上面搭建的网络的类型是一样的,自然通过训练后的功能是一样的,关于训练的 代码直接贴在这里,具体的不多赘述,可以参考

optimizer = torch.optim.SGD(net.parameters(),lr = 0.05)
loss_func = torch.nn.MSELoss()

plt.ion()
plt.show()

for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction,y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t%5 ==0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss = %.4f' % loss.data, fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.05)

那么对于第一部分中已经训练好的网络,需要做的就是对训练好的网络进行保存,保存的方式有两种,一种是直接对网络的所有都进行保存[2],一种对网络中的参数进行保存,保存优化选项默认字典[3]

二、网络的保存

在训练后的网络中直接进行两种不同的保存方式

torch.save(net,'net.pkl')    #保存所有的网络参数
torch.save(net.state_dict(),'net_parameter.pkl')    #保存优化选项默认字典,不保存网络结构

运行后在当前目录生成指定pkl文件
在这里插入图片描述

三、网络的提取

1、提取整个网络的方法

直接调用torch.load来提取整个网络

torch.load从磁盘文件中读取一个通过torch.save()保存的对象。torch.load()可通过参数map_location动态地进行内存重映射,使其能从不动设备中读取文件。一般调用时,需两个参数: storagelocation tag. 返回不同地址中的storage,或着返回None (此时地址可以通过默认方法进行解析). 如果这个参数是字典的话,意味着其是从文件的地址标记到当前系统的地址标记的映射。[4]

net1 = torch.load('net.pkl')

2、提取网络中的参数的方法

对于提取网络中的参数的方式,必须先完整的建立和需要提取的网络一样的结构的网络,再去提取参数进而恢复网络

net2 = torch.nn.Sequential(
    nn.Linear(1,20),
    torch.nn.ReLU(),
    nn.Linear(20,20),
    torch.nn.ReLU(),
    nn.Linear(20,1)
)
net2.load_state_dict(torch.load('net_parameter.pkl'))

以上两种方法恢复的网络是一样的,有的朋友肯定会问,既然都一样的,为什么我们不选择直接恢复而是选择先建立一样的网络再去恢复参数。因为在大型神经网络中,网络的结构很复杂网络的参数也很发杂,所以直接保存整个网络会占用很大的磁盘资源,就本次实验的一个例子就可以看出,保存网络参数和保存网络结构对磁盘的占用是完全不同的,所以在大型神经网络中更倾向于用保存参数的方式去保存真个网络。
在这里插入图片描述
接下里咱们来看看整个网络恢复后的功能吧

prediction1 = net1(x)
prediction2 = net2(x)

#可视化的部分
plt.figure(1,figsize=(10,3))
plt.subplot(121)
plt.title('net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction1.data.numpy(), 'r-', lw=5)
# plt.show()

plt.figure(1,figsize=(10,3))
plt.subplot(122)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction2.data.numpy(), 'r-', lw=5)
plt.show()

在这里插入图片描述
附整个程序的所有代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable

x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

x , y =(Variable(x),Variable(y))

'''
class Net(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden1 = nn.Linear(n_input,n_hidden)
        self.hidden2 = nn.Linear(n_hidden,n_hidden)
        self.predict = nn.Linear(n_hidden,n_output)
    def forward(self,input):
        out = self.hidden1(input)
        out = F.relu(out)
        out = self.hidden2(out)
        out = F.relu(out)
        out =self.predict(out)

        return out
'''

'''
net = torch.nn.Sequential(
    nn.Linear(1,20),
    torch.nn.ReLU(),
    nn.Linear(20,20),
    torch.nn.ReLU(),
    nn.Linear(20,1)
)

# net = Net(1,20,1)
print(net)
'''
'''
optimizer = torch.optim.SGD(net.parameters(),lr = 0.05)
loss_func = torch.nn.MSELoss()

plt.ion()
plt.show()

for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction,y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t%5 ==0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss = %.4f' % loss.data, fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.05)

torch.save(net,'net.pkl')
torch.save(net.state_dict(),'net_parameter.pkl')
'''

net1 = torch.load('net.pkl')
net2 = torch.nn.Sequential(
    nn.Linear(1,20),
    torch.nn.ReLU(),
    nn.Linear(20,20),
    torch.nn.ReLU(),
    nn.Linear(20,1)
)
net2.load_state_dict(torch.load('net_parameter.pkl'))

prediction1 = net1(x)
prediction2 = net2(x)

plt.figure(1,figsize=(10,3))
plt.subplot(121)
plt.title('net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction1.data.numpy(), 'r-', lw=5)
# plt.show()

plt.figure(1,figsize=(10,3))
plt.subplot(122)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction2.data.numpy(), 'r-', lw=5)
plt.show()
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值