PyTorch笔记5-save和load神经网络

本系列笔记为莫烦PyTorch视频教程笔记 github源码

概要

用 PyTorch 训练好神经网络(NN)后,如何保存以便下次要用的时候直接提取使用即可,下面举栗

import torch
from torch.autograd import Variable
import torch.nn.functional as F      # activation function
import matplotlib.pyplot as plt

torch.manual_seed(1)   # torch seed

%matplotlib inline
# fake data
# unsqueeze set shape, otherwise (100,)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # shape(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())
# only the variable can be trained
x, y = Variable(x), Variable(y)

# build NN
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)     # hidden layer
        self.prediction = torch.nn.Linear(n_hidden, n_output)  # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))    # activation func for hidden layer
        x = self.prediction(x)

        return x

net1 = Net(1, 10, 1)
print('net1: \n', net1)
net1: 
 Net (
  (hidden): Linear (1 -> 10)
  (prediction): Linear (10 -> 1)
)
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)

for epoch in range(100):
    prediction = net1(x)
    loss = loss_func1(prediction, y)

    optimizer.zero_grad()     # clear gradient for next train
    loss.backward()
    optimizer.step()

保存神经网络

下面用两种方式来保存
net1 为保存整个网络
net2 只保存网络中的参数(速度快,占内存少)

torch.save(net1, './NNPkl/net.pkl')
torch.save(net1.state_dict(), './NNPkl/net_parms.pkl')
/Users/yangjiahua/pytorch-test/lib/python3.6/site-packages/torch/serialization.py:147: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "

提取神经网络

也有两种方式提取:提取整个网络以及提取网络参数
net2 为提取整个神经网络
net3 为提取神经网络参数,注意,该方式需要先建立一个跟所提取神经网络参数一样的网络架构,然后再赋予参数

net2 = torch.load('./NNPkl/net.pkl')
net3 = Net(1, 10, 1)                 # first build same NN as net1
net3.load_state_dict(torch.load('./NNPkl/net_parms.pkl'))

可视化比较

画图查看提取保存中的网络跟原来训练的是否一致

plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

plt.subplot(132)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), net2(x).data.numpy(), 'r-', lw=5)

plt.subplot(133)
plt.title('net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), net3(x).data.numpy(), 'r-', lw=5)
[<matplotlib.lines.Line2D at 0x10d9479e8>]

这里写图片描述

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值