我们接下来需要用CIFAR-10数据集进行分类,步骤如下:使用torchvision 加载并预处理CIFAR-10数据集
定义网络
定义损失函数和优化器
训练网络并更新网络参数
测试网络
注意:文章末尾含有项目jupyter notebook实战教程下载可供大家课后实战操作
一、CIFAR-10数据加载及预处理
CIFAR-10 是一个常用的彩色图片数据集,它有 10 个类别,分别是 airplane、automobile、bird、cat、deer、dog、frog、horse、ship和 truck。每张图片都是 3*32*32 ,也就是 三通道彩色图片,分辨率 32*32。
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch as t
#可以把Tensor转化为Image,方便可视化
show = ToPILImage()
#先伪造一个图片的Tensor,用ToPILImage显示
fake_img = t.randn(3, 32, 32)
#显示图片
show(fake_img)
1
2
3
4
5
6
7
8
9
10
11
12
13importtorchvisionastv
importtorchvision.transformsastransforms
fromtorchvision.transformsimportToPILImage
importtorchast
#可以把Tensor转化为Image,方便可视化
show=ToPILImage()
#先伪造一个图片的Tensor,用ToPILImage显示
fake_img=t.randn(3,32,32)
#显示图片
show(fake_img)
第一次运行torchvision会自动下载CIFAR-10数据集,大约163M。这里我将数据直接放到项目 data文件夹 中。
cifar_dataset = tv.datasets.CIFAR10(root=\'data\',
train=True,
download=True
)
imgdata, label = cifar_dataset[90]
print(\'label: \', label)
print(\'imgdata的类型:\',type(imgdata))
imgdata
1
2
3
4
5
6
7
8
9cifar_dataset=tv.datasets.CIFAR10(root=\'data\',
train=True,
download=True
)
imgdata,label=cifar_dataset[90]
print(\'label:\',label)
print(\'imgdata的类型:\',type(imgdata))
imgdata
运行结果
Files already downloaded and verified
label: 2
imgdata的类型:
1
2
3Filesalreadydownloadedandverified
label:2
imgdata的类型:
注意,数据集中的照片数据是以 PIL.Image.Image类 形式存储的,在我们加载数据时,要注意将其转化为 Tensor类。
def dataloader(train):
transformer = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5),
std = (0.5, 0.5, 0.5))
])
cifar_dataset = tv.datasets.CIFAR10(root=\'data\', #下载的数据集所在的位置
train=train, #是否为训练集。
download=True, #设置为True,不用再重新下载数据
transform=transformer
)
loader = t.utils.data.DataLoader(
cifar_dataset,
batch_size=4,
shuffle=True, #打乱顺序
num_workers=2 #worker数为2
)
return loader
classes=(\'plane\', \'car\', \'bird\', \'cat\', \'deer\', \'dog\', \'frog\', \'horse\', \'ship\', \'truck\')
#训练集和测试集的加载器
trainloader = dataloader(train=True)
testloader = dataloader(train=False)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28defdataloader(train):
transformer=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,0.5,0.5),
std=(0.5,0.5,0.5))
])
cifar_dataset=tv.datasets.CIFAR10(root=\'data\',#下载的数据集所在的位置
train=train,#是否为训练集。
download=True,#设置为True,不用再重新下载数据
transform=transformer
)
loader=t.utils.data.DataLoader(
cifar_dataset,
batch_size=4,
shuffle=True,#打乱顺序
num_workers=2#worker数为2
)
returnloader
classes=(\'plane\',\'car\',\'bird\',\'cat\',\'deer\',\'dog\',\'frog\',\'horse\',\'ship\',\'truck\')
#训练集和测试集的加载器
trainloader=dataloader(train=True)
testloader=dataloader(train=False)
运行结果
Files already downloaded and verified
Files already downloaded and verified
1
2Filesalreadydownloadedandverified
Filesalreadydownloadedandverified
DataLoader是一个可迭代的对象,它将dataset返回的每一条数据样本拼接成一个batch,并提供多线程加速优化和数据打乱等操作。当程序对 cirfar_dataset 的所有数据遍历完一遍, 对Dataloader也完成了一次迭代。
dataiter = iter(trainloader)
#返回四张照片及其label
images, labels = dataiter.next()
#打印多张照片
show(tv.utils.make_gri