pytorch笔记6--网络的保存和提取

一、步骤

1.创建数据

import torch
import torch.nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x**2+0.2*torch.rand(x.size())
x=Variable(x)
y=Variable(y)

2.搭建网络

def save():
    #快速搭建网络
    net=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
    loss_func=torch.nn.MSELoss()
    #训练网络
    for i in range(100):
        prediction=net(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #可视化
    plt.figure(1,figsize=(10,3))
    plt.subplot(131)
    plt.title('Ne1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),color='red',lw='5')

3.保存网络的两种方法:

    #保存网络的两种方法
    torch.save(net,'net.pkl')        #保存整个网络
    torch.save(net.state_dict(),'net_params.pkl')  #只保存网络中的参数。速度款,占内存少

4.提取网络(两种不同的提取方法):

  • 提取整个神经网络
#提取整个神经网络
def restore_net():
    #提取整个神经网络,网络大的时候可能会提取的比较慢
    net=torch.load('net.pkl')
    prediction=net(x)

    #可视化
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), color='red', lw='5')
  • 提取神经网络的参数
#提取网络参数
def resotre_params():
    #先新建一个网络
    net=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    #将提取出来的已经训练好的参数赋给网络net
    net.load_state_dict(torch.load('net_params.pkl'))
    prediction=net(x)

    #可视化
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), color='red', lw='5')

二、运行及结果显示

#两种方法保存网络
save()
#提取整个网络
restore_net()
#提取网络的参数,复制到新的网络
resotre_params()
plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值