如何把数据集和 transform 组合在一起使用。
一、引入
每次导入数据集的时候,都要通过 torchvision ,也即先要 import torchvision
二、举例
可以在pytorch官网上找到各类数据集,例如:CIFAR10,附上官网地址:CIFAR10 — 火炬视觉主要文档 (pytorch.org)。该数据集基本数据如下
![](https://img-blog.csdnimg.cn/img_convert/e8c5aa6c10c826cee3fa26b783b90a79.png)
1、参数说明:
root:是指根目录,即数据集的位置;
train:是 True的时候,表示创建的数据集是训练集,为False表示创建了一个测试集;
Transform:表示对数据集进行的变化;
target_transform:对于 target(目标)进行改变;
download:如果是 true,从网上下载这个数据集,是 False,就不会下载;
数据集下载以及初步读取:
import torchvision
train_set = torchvision.datasets.CIFAR10(root='./dataset',train=True,download=True)
# 以上是训练数据集,下面是训练数据集
test_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,download=True)
print(test_set[0])
![](https://img-blog.csdnimg.cn/img_convert/2e1e38af3b3b78c9416da71ce66ff6dd.png)
以上代码进行 debug,会出现代码相关信息,其中 ,class_to_idx 说的就是数据集内的标签(target),这里输出的是 test_set[0],所以对应 'airplane';classes 是分类。在知道分类的情况下,可以用其他方式进行输出:
print(test_set.classes)
# 输出:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
同时,知道组成部分是(图片,标签)--->(img, target),那么引用方法可以如下所示:
img,target = test_set[0]
print(img) # 结果:<PIL.Image.Image image mode=RGB size=32x32 at 0x165BA56A288>
print(target) # 结果:3
print(test_set.classes[target]) # 结果:cat
# classes得到的是一个列表,通过对应的target标签,也就是列表的下标去获取对应的类别
img.show() # 输出的是一张图片
![](https://img-blog.csdnimg.cn/img_convert/ff96578c14facabe090588f80c374dbb.png)
因为图片只有32*32,所以很小
2、数据集和transform联动
最大的问题是图片格式问题,要把 PIL 转变成 tensor 格式,这就需要用transform 实现。
实例:
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=dataset_transform ,download=True)
# 以上是训练数据集,下面是训练数据集
test_set = torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=dataset_transform ,download=True)
writer = SummaryWriter('p10')
for i in range(10):
img, target = test_set[i]
writer.add_image('test_set',img, i) # 最后一个参数:步径,就是第0张图片是什么,第 i 张图片是什么
writer.close()
输出的就是十张图片,也是在 transform 上显示。