【Pytorch 学习笔记(四)】:训练分类器

本文介绍了使用PyTorch训练CIFAR10图像分类器的详细步骤,包括数据加载和标准化、定义卷积神经网络、设置损失函数和优化器、训练与测试网络,以及GPU加速的相关内容。
摘要由CSDN通过智能技术生成

训练分类器

到目前为止,我们已经了解了如何定义NN,计算loss以及更新网络的权重。那我们接下来再来看一下数据的处理

data

当我们需要处理图像、文本、或音/视频文件时,通常可以找到一些python包来载入数据到numpy数组中。然后我们可以将这个数组再转化为torch.*Tensor

  • 对图片来说,可以用Pillow,OpenCV
  • 对音频来说,可以用scipy,librosa
  • 对文本来说,可以用python或Cython自带的载入,或者是NLTKSpaCy

特别地,对于视觉方面,可以使用torchvision,其中包括了对Imagenet、CIFAR10、MNIST等常用数据集的数据加载器(data loaders),包括对图片数据变形的操作。即torchvision.datasetstorch.utils.data.DataLoader

在这个教程中,我们将使用CIFAR10数据集。它有如下的分类:airplane、automobile、bird、 cat、 deer、 dog、 frog、horse、 ship、 truck 等。在CIFAR-10里面的图片数据大小是3x32x32,即三通道彩色图,图片大小是32x32像素。

training an image classifier

接下来会逐步做如下操作,其实也就是训练分类器的步骤:

  • 通过torchvision加载CIFAR10里面的训练和测试数据集,并对数据进行标准化处理
  • 定义卷积神经网络
  • 定义损失函数
  • 利用训练数据训练网络
  • 利用测试数据测试网络

1. Loading and normalizing CIFAR10

使用torchvision可以很方便地加载CIFAR10

import torch
import torchvision 
import torchvision.transforms as transforms

torchvision数据集加载完的输出为范围在[0,1]之间的PILImage图片。我们需要将其标准化为范围为[-1,1]之间的张量。下面的Normalize,param第一项是mean序列,此处为3项,因为有三个channel。第二项是std序列,同样有3项。同时注意此处的转换是out-of-place,它不会改变原有输出。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

输出:

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值