pytorch 三维点分类_PyTorch 实战:使用卷积神经网络对照片进行分类

我们接下来需要用CIFAR-10数据集进行分类,步骤如下:使用torchvision 加载并预处理CIFAR-10数据集

定义网络

定义损失函数和优化器

训练网络并更新网络参数

测试网络

注意:文章末尾含有项目jupyter notebook实战教程下载可供大家课后实战操作

一、CIFAR-10数据加载及预处理

CIFAR-10 是一个常用的彩色图片数据集,它有 10 个类别,分别是 airplane、automobile、bird、cat、deer、dog、frog、horse、ship和 truck。每张图片都是 3*32*32 ,也就是 三通道彩色图片,分辨率 32*32。

import torchvision as tv

import torchvision.transforms as transforms

from torchvision.transforms import ToPILImage

import torch as t

#可以把Tensor转化为Image,方便可视化

show = ToPILImage()

#先伪造一个图片的Tensor,用ToPILImage显示

fake_img = t.randn(3, 32, 32)

#显示图片

show(fake_img)

1

2

3

4

5

6

7

8

9

10

11

12

13importtorchvisionastv

importtorchvision.transformsastransforms

fromtorchvision.transformsimportToPILImage

importtorchast

#可以把Tensor转化为Image,方便可视化

show=ToPILImage()

#先伪造一个图片的Tensor,用ToPILImage显示

fake_img=t.randn(3,32,32)

#显示图片

show(fake_img)

第一次运行torchvision会自动下载CIFAR-10数据集,大约163M。这里我将数据直接放到项目 data文件夹 中。

cifar_dataset = tv.datasets.CIFAR10(root=\'data\',

train=True,

download=True

)

imgdata, label = cifar_dataset[90]

print(\'label: \', label)

print(\'imgdata的类型:\',type(imgdata))

imgdata

1

2

3

4

5

6

7

8

9cifar_dataset=tv.datasets.CIFAR10(root=\'data\',

train=True,

download=True

)

imgdata,label=cifar_dataset[90]

print(\'label:\',label)

print(\'imgdata的类型:\',type(imgdata))

imgdata

运行结果

Files already downloaded and verified

label: 2

imgdata的类型:

1

2

3Filesalreadydownloadedandverified

label:2

imgdata的类型:

注意,数据集中的照片数据是以 PIL.Image.Image类 形式存储的,在我们加载数据时,要注意将其转化为 Tensor类。

def dataloader(train):

transformer = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5),

std = (0.5, 0.5, 0.5))

])

cifar_dataset = tv.datasets.CIFAR10(root=\'data\', #下载的数据集所在的位置

train=train, #是否为训练集。

download=True, #设置为True,不用再重新下载数据

transform=transformer

)

loader = t.utils.data.DataLoader(

cifar_dataset,

batch_size=4,

shuffle=True, #打乱顺序

num_workers=2 #worker数为2

)

return loader

classes=(\'plane\', \'car\', \'bird\', \'cat\', \'deer\', \'dog\', \'frog\', \'horse\', \'ship\', \'truck\')

#训练集和测试集的加载器

trainloader = dataloader(train=True)

testloader = dataloader(train=False)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28defdataloader(train):

transformer=transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5,0.5,0.5),

std=(0.5,0.5,0.5))

])

cifar_dataset=tv.datasets.CIFAR10(root=\'data\',#下载的数据集所在的位置

train=train,#是否为训练集。

download=True,#设置为True,不用再重新下载数据

transform=transformer

)

loader=t.utils.data.DataLoader(

cifar_dataset,

batch_size=4,

shuffle=True,#打乱顺序

num_workers=2#worker数为2

)

returnloader

classes=(\'plane\',\'car\',\'bird\',\'cat\',\'deer\',\'dog\',\'frog\',\'horse\',\'ship\',\'truck\')

#训练集和测试集的加载器

trainloader=dataloader(train=True)

testloader=dataloader(train=False)

运行结果

Files already downloaded and verified

Files already downloaded and verified

1

2Filesalreadydownloadedandverified

Filesalreadydownloadedandverified

DataLoader是一个可迭代的对象,它将dataset返回的每一条数据样本拼接成一个batch,并提供多线程加速优化和数据打乱等操作。当程序对 cirfar_dataset 的所有数据遍历完一遍, 对Dataloader也完成了一次迭代。

dataiter = iter(trainloader)

#返回四张照片及其label

images, labels = dataiter.next()

#打印多张照片

show(tv.utils.make_gri

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值