PyTorch深度学习60分钟闪电战:04 训练一个分类器

本系列是PyTorch官网Tutorial Deep Learning with PyTorch: A 60 Minute Blitz 的翻译和总结。

  1. PyTorch概览
  2. Autograd - 自动微分
  3. 神经网络
  4. 训练一个分类器

下载本文的Jupyter NoteBook文件:60min_01_PyTorch Overview.ipynb


到目前,你已经知道了如何定义神经网络、计算损失、更新网络权重。

那么,数据呢?

数据

通常,当你处理图像,文本,音频或视频数据时,可以使用标准python包将数据加载到numpy数组中。然后便可以将数组转换为torch.*Tensor

  • 对于图像,Pillow,OpenCV等软件包很有用
  • 对于音频,可以使用scipy和librosa等包
  • 对于文本,基于Python或Cython的加载、NLTK和SpaCy都很有用

特别地,对于图像领域,我们创建了一个叫做torchvision的包,这个包包含了支持常见数据集(如Imagenet, CIFAR10, MNIST等)的数据加载器以及图像的数据转换器,即torchvision.datasetstorch.utils.data.DataLoader

这提供了极大的便利,并且避免了编写样板代码。

在这个教程中,我们将会使用CIFAR10数据集。它具有以下类别:“飞机”,“汽车”,“鸟”,“猫”,“鹿”,“狗”,“青蛙”,“马”,“船”,“卡车”。CIFAR-10中的图像尺寸为3×32×32,即尺寸为32x32像素的3通道彩色图像。

训练图像分类器

我们将按顺序执行以下步骤:

  1. 使用以下命令加载和标准化CIFAR10训练和测试数据集 torchvision
  2. 定义卷积神经网络
  3. 定义损失函数
  4. 根据训练数据训练网络
  5. 在测试数据上测试网络

1. 加载CIFAR10并将其规则化

torchvision的帮助下,加载CIFAR10非常容易。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision数据集的输出是在[0,1]范围的PILImage。我们将它们转换成范围是[-1,1]的张量。

Note:

如果在Windows上运行时你遇到了一个BrokenPipeError,请尝试将torch.utils.data.DataLoader()num_worker设置为0。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog',
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值