pytorch初学笔记(五):torchvision中dataset的最详细使用(以CIFAR10和MNIST为例)

目录

一、torchvision介绍

1. 作用与结构

2. torchvision中常用数据集

二、CIFAR10的介绍

1.  数据集简介

2. 使用该数据集的所需参数 

3. 数据集下载

3.1 pycharm在线下载(下载速度较快时) 

3.2 第三方下载

3.3 数据库的下载总结 

三、 CIFAR10的具体使用

1. 数据集对象的显示(PIL型)

2. 把数据集中的图片对象转换为tensor型

2.1 转换所需transform的定义

2.2 使用tensorboard进行图片显示

四、练习:MNIST数据集的下载和使用

1. 可能的报错和修改 

2. 代码实现

2.1 PIL对象实现

2.2 tensor对象实现

3. 运行结果 


一、torchvision介绍

1. 作用与结构

torchvision — Torchvision main documentation

torchvision是pytorch下的一个包,主要由计算机视觉中的流行数据集、模型体系结构和常见图像转换等模块组成。

 常用的包:

  • Transforming and augmenting images:进行图片变换等。
  • Models and pre-trained weights:提供一些预训练好的神经网络或权重参数等。
  • Dataset :提供常用的数据集。

2. torchvision中常用数据集

Datasets — Torchvision main documentation

 Datasets模块提供了需要常用的数据集以及其具体的使用方法,比如下图所示的图像分类中常用的CIFAR10数据集,图像检测中常用的COCO数据集等。

 

 下面具体说明如何对CIFAR10进行下载和使用。

二、CIFAR10的介绍

1.  数据集简介

CIFAR-10 and CIFAR-100 datasets (toronto.edu)

  •     CIFAR-10是一个更接近普适物体的彩色图像的小型数据集
  •     一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
  •      每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

2. 使用该数据集的所需参数 

CIFAR10 — Torchvision main documentation

需要设定的5个参数:

1.   root(字符串型):把数据集下载到的位置路径。

2.   train(布尔型):是否把该数据集作为训练数据集使用。

  • True: 作为训练数据集创建
  • False:不作为训练数据集,作为测试数据集创建

3.   transform:图像需要进行的变换操作,一般使用compose把所需的transforms结合起来。

4.   target_transform:对于标签需要做的变换

5.   download(布尔型):是否下载数据集。

  • True:把数据集下载到root指定的对应位置;如果数据集以及进行过下载,则不会再一次下载
  • False:不下载数据集

3. 数据集下载

3.1 pycharm在线下载(下载速度较快时) 

    1. 导入torchvision包,然后依次创建训练数据集和测试数据集。

注意:训练数据集的train参数要设置为True,测试数据集的train设置为False

import torchvision
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,download=True)

    2. 点击运行,等待一段时间后显示下载成功 

    3. 观察项目包目录,可以发现自动创建了名为dataset3的文件夹,下载的解压文件和解压好的数据集都在其中。

3.2 第三方下载

    如果在pycharm中下载速度很慢的话,可以找到pycharm所用的下载链接,然后自己使用迅雷等下载软件进行快速下载。

如何找到下载链接?

  1. 把鼠标移动到想要下载的数据集名称上,然后Ctrl+C,进入该数据集的帮助文档。 

      2. 可以看到对应的下载文件名和下载链接。 

    3. 使用迅雷或者浏览器下载,然后把下载过后的压缩文件按照root中定义的路径创建文件夹,然后把文件放入文件夹中,注意,自己创建的文件夹一定要和root中定义的文件夹姓名相同才行,否则后期扫描不到该数据集

    4. 运行上面在线下载中定义的语句,可以发现程序不会再次下载数据集文件,而是会帮你解压好数据集。

3.3 数据库的下载总结 

无论是否需要在线下载数据集,都推荐把download参数值设为True。

因为程序可以帮你自动完成下载解压工作,就算自己下载过文件,也可以提供解压功能,因此更加方便。

三、 CIFAR10的具体使用

1. 数据集对象的显示(PIL型)

import torchvision
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,download=True)

#1. 查看数据集的图片
#输出所有类别
print(test_set.classes)
#输出数据集第一张图片的类型
print(test_set[0])
#输出图片的PIL型格式和标签
img,label = test_set[0]
print(label,test_set.classes[label])
img.show()

        1.  数据集所有类别的查看

        图片有十个类,对应的类别名称存储在dataset.classes列表中。

        2. 数据集中单个具体对象的查看

        想要输出数据集中具体的某一张图片,使用下标调用方式dataset[x]即可显示第x+1张图片;输出的对象类型为一个元组,里面第一项是PIL类型的图片,第二项是图片的标签。

        3. 数据集中图片对象和标签的定义

        可以使用  img,label = dataset[x] 的方式接收对象中的图片和label,然后可以用print进行对label的输出,也可以用 dataset. classes[label]的格式进行对该类别名称的显示。

        4. 数据集中图片的可视化

        使用img.show()方法进行图片的可视化显示。

输出结果如下: 

        打开的对应图片如下图所示,由于数据集中的图片较小,所以不清晰,但是可以看出来是一只小猫的图片。 

2. 把数据集中的图片对象转换为tensor型

2.1 转换所需transform的定义

        因为需要完成数据集中所有图片类型从PIL到tensor的转换,我们需要用到transforms工具,也需要设定数据集中的transform参数。

        我们在数据集定义的语句之前定义我们需要的transform,由于一般需要对图像做的变换不止一个,所以我们使用compose来对多个transforms进行组合在这里我们只需要一个ToTensor即可。

        下面代码给出使用compose定义transform和不使用compose的两个版本,都可以完成成功运行。

  •  使用compose:
import torchvision
#定义transforms
dataset_transform = torchvision.transforms.Compose([
    #定义totensor
    torchvision.transforms.ToTensor()
])
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=dataset_transform,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=dataset_transform,download=True)

  • 不使用compose:
import torchvision
#定义transforms
from torch.utils.tensorboard import SummaryWriter

trans_totensor_tool = torchvision.transforms.ToTensor()
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=trans_totensor_tool,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=trans_totensor_tool,download=True)
、

2.2 使用tensorboard进行图片显示

        完成了transform和数据集的定义后,即可使用add_image()方法完成图片显示。在这里我们使用for循环进行10张图片的显示。

import torchvision
#定义transforms
from torch.utils.tensorboard import SummaryWriter

trans_totensor_tool = torchvision.transforms.ToTensor()
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=trans_totensor_tool,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=trans_totensor_tool,download=True)

#使用tensorboard进行显示
writer = SummaryWriter("logs")
#for循环完成10张图片的显示
for i in range(10):
    img,label=test_set[i]
    writer.add_image("dataset",img,i)

writer.close()

        结果如下所示。可以看到一共step=9,成功显示了数据集中第1-10张图片。

          

          

四、练习:MNIST数据集的下载和使用

1. 可能的报错和修改 

        使用上面做过的练习对MNIST数据集进行相同的操作,注意在下载数据集后可能会爆“UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors.” 的错误,按照博文的方法修改即可。 

(4条消息) Pytorch | 报错The given NumPy array is not writeable,and PyTorch does not support non-writeable tensor_软耳朵DONG的博客-CSDN博客

2. 代码实现

对于PIL对象:

  • 完成数据集所有类别的输出(classes)
  • 输出数据集中的第一个对象
  • 完成前10张图片对应类别的输出
  • 完成第10张图片的显示(show方法)

对于tensor对象:

  • 把数据集中所有图片类型从PIL型转换为tensor型,重定义图片大小为10*10(使用Compose,ToTensor和Resize)
  • 输出前10张图片

2.1 PIL对象实现

import torchvision
from torch.utils.tensorboard import SummaryWriter


train_set = torchvision.datasets.MNIST(root="./MNIST_test",train=True,download=True)
test_set = torchvision.datasets.MNIST(root="./MNIST_test",train=False,download=True)

#pil型对象显示
print(test_set.classes)
print(test_set[0])
for i in range(10):
    img,label=test_set[i]
    print(test_set.classes[label])
img.show()

2.2 tensor对象实现

import torchvision
from torch.utils.tensorboard import SummaryWriter

trans_tool = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((10,10))
])

train_set = torchvision.datasets.MNIST(root="./MNIST_test",train=True,transform=trans_tool,download=True)
test_set = torchvision.datasets.MNIST(root="./MNIST_test",train=False,transform=trans_tool,download=True)

#tensor型对象显示
writer = SummaryWriter("logs")
for i in range(10):
    img,label=test_set[i]
    writer.add_image("MNIST",img,i)
print(img.shape)
writer.close()

3. 运行结果 

 数据集下载并创建成功:

 显示第10张图片:

 print的显示结果:

在未改变大小之前的维度是(1,28,28),resize后可见tensor的维度变成了(1,10,10 )

,

 tensoeboard显示结果: 

  • 9
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
首先,让我们先下载并导入 `torchvision` 和 `matplotlib` 库: ```python import torch import torchvision import matplotlib.pyplot as plt ``` 然后,我们可以使用以下代码加载 CIFAR10 数据集: ```python transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) ``` 在这里,我们定义了一个转换器,用于将 PIL 图像转换为 PyTorch 张量,并对像素值进行归一化。然后,我们使用 `torchvision.datasets.CIFAR10` 类加载数据集,`train=True` 表示加载训练集,`train=False` 表示加载测试集。`root` 参数指定数据集存储的文件夹,`download=True` 表示自动下载数据集。 接下来,我们可以使用以下代码显示数据集的一些示例图像: ```python classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') def imshow(img): img = img / 2 + 0.5 # 非归一化 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() dataiter = iter(trainloader) images, labels = dataiter.next() imshow(torchvision.utils.make_grid(images)) print(' '.join('%5s' % classes[labels[j]] for j in range(4))) ``` 在这里,我们首先定义了一个 `classes` 列表,包含 CIFAR10 的所有类别。然后,我们定义了一个函数 `imshow()`,用于显示图像。我们从训练集加载一批数据,并使用 `torchvision.utils.make_grid()` 函数将这些图像合并为一个网格。最后,我们使用 `plt.imshow()` 函数显示图像。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值