python进阶教程:PyTorch快速搭建神经网络及其保存提取方法详解

@本文来源于公众号:csdn2299,喜欢可以关注公众号 程序员学府
有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解
一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
  
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
  
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
  
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
  
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
  
''''' 
方法1: 
 Net ( 
 (hidden):
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,主要应用于图像识别、语音识别等领域。在pytorch中,可以使用torch.nn模块来构建CNN。 下面以一个图像分类的例子来详细介绍CNN的应用及实现。 ## 数据集 我们使用的是CIFAR-10数据集,它包含10个类别的60000张32x32彩色图片。其中50000张用于训练,10000张用于测试。每个类别的训练集和测试集都有5000张图片。 ## 数据预处理 首先,我们需要对图像行预处理,将其转换为tensor,并行标准化处理。 ```python import torch import torchvision import torchvision.transforms as transforms # 数据预处理 transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 加载数据集 trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) ``` 上述代码中,我们定义了一个transform,它将图像转换为tensor,并行标准化处理。接着,我们使用torchvision加载CIFAR-10数据集,并定义一个DataLoader来对数据行批处理。 ## 定义CNN模型 我们定义一个简单的CNN模型,包括2个卷积层和3个全连接层。 ```python import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() ``` 上述代码中,我们定义了一个Net类,继承自nn.Module。在构造函数中,我们定义了2个卷积层(分别包含6个和16个卷积核),3个全连接层。在forward函数中,我们先通过卷积层和池化层行特征提取,然后将特征展开成一维向量,再通过全连接层行分类。 ## 定义损失函数和优化器 我们使用交叉熵损失函数和随机梯度下降优化器。 ```python import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) ``` ## 训练网络 我们使用训练集行训练,每次迭代都对网络参数行优化。 ```python for epoch in range(2): # 循环遍历数据集多次 running_loss = 0.0 for i, data in enumerate(trainloader, 0): # 获取输入数据 inputs, labels = data # 梯度清零 optimizer.zero_grad() # 前向传播,计算损失 outputs = net(inputs) loss = criterion(outputs, labels) # 反向传播,更新网络参数 loss.backward() optimizer.step() # 记录损失值 running_loss += loss.item() if i % 2000 == 1999: # 每2000个批次打印一次平均损失值 print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 print('Finished Training') ``` ## 测试网络 我们使用测试集行测试,并计算网络的准确率。 ```python correct = 0 total = 0 with torch.no_grad(): for data in testloader: # 获取输入数据 images, labels = data # 前向传播,输出预测结果 outputs = net(images) _, predicted = torch.max(outputs.data, 1) # 统计准确率 total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: %d %%' % ( 100 * correct / total)) ``` 上述代码中,我们使用torch.no_grad()来关闭梯度计算,这样可以减少内存的占用。在循环中,我们通过torch.max函数找到每个样本预测结果的最大值,并与标签行比较,统计准确率。 ## 总结 本文介绍了如何使用pytorch构建CNN模型,并对CIFAR-10数据集行图像分类。通过本文的学习,你可以了解到CNN的基本原理及实现方法

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值