6-pytorch - 网络的保存和提取

本文介绍了如何在PyTorch中保存和提取训练好的神经网络,包括保存整个网络和单独保存参数。作者建议保存参数以提高效率,并提供了使用`torch.save`和`load_state_dict`的具体示例。
摘要由CSDN通过智能技术生成

前言

我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。

一、生成训练数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)

# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

输出:
在这里插入图片描述

二、网络保存

def save():
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(), 
        torch.nn.Linear(10,1)
    )
    optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 下面介绍两种不同的保存方法,方法二可能运行速度要快点
    # 保存整个网络的所有
    torch.save(net1, 'net.pkl')     
    # 保存好网络的参数
    torch.save(net1.state_dict(),'net_params.pkl')
    
    # plot result
    plt.figure(1,figsize=(10,3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。

三、网络提取

def restore_net():
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
    
def restore_params():
    # 如果只是保留参数的情况,提取时需要再次定义相同网络才行
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

四、对保存网络提取进行结果展示

save()
restore_net()
restore_params()

在这里插入图片描述

  • 9
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
DGCNN(Dynamic Graph Convolutional Neural Networks)是一种用于图分类的神经网络模型,通过动态图卷积实现对图结构的特征提取和分类。下面是使用DGCNN模型的PyTorch库的方法。 首先,安装所需的Python库和PyTorch。然后,从GitHub上下载并安装DGCNN-PyTorch库。可以使用以下命令: ``` git clone https://github.com/muhanzhang/pytorch_DGCNN.git cd pytorch_DGCNN pip install -r requirements.txt ``` 下载并准备数据集,将数据集存储在合适的路径下。接下来,根据实际的数据集和训练需要调整hyperparameters文件中的参数。 运行`main.py`文件进行训练和测试。可以使用以下命令: ``` python main.py ``` 在命令行中,可以设置一些参数,如数据集路径、模型保存路径、训练时的批次大小、迭代次数等。训练过程将根据设置的参数进行训练,并在测试集上评估模型性能。 另外,如果需要使用预训练的模型进行图分类任务,可以通过以下命令加载预训练模型并进行预测: ``` python from model import DGCNN import torch device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = DGCNN().to(device) model.load_state_dict(torch.load('saved_model.pth')) model.eval() # 假设有一个图样本data output = model(data) # 进行预测 ``` 以上是使用DGCNN-PyTorch库的基本方法。根据实际任务和数据集,您可能需要进行适当的调整和修改。希望对您有所帮助!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值