08 torchvision中的数据集使用

打开PyTorch官网,官方文档中第一部分是PyTorch的核心模块,torchaudio是处理PyTorch语音的,torchtext是处理文本的,torchivision是处理图像的。

在这里插入图片描述

打开torchvision,tensorboard和transforms均来源于这里,torchvision分了好几个模块,包括Datasets即数据集的API文档,只要在写代码时指定相应数据集的参数,它就能去下载使用对应的数据集。

在这里插入图片描述

COCO数据集一般用于目标检查、语义分割;MINIST一般作为教科书中的入门数据集,是手写文字数据集;CIFAR一般用于物体识别。

torchvision.models中提供了最常用的一些神经网络模块,这些模块已经训练好了。其中有分类、语义分割、目标检测、视频分类等数据集。

torchvision中的tensorboard和transforms模块已经讲解过了。

CIFAR 10 Dataset

  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.

  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.

  • transform (callable_,_ optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

  • target_transform (callable_,_ optional) – A function/transform that takes in the target and transforms it.

  • download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

  1. root表示数据集存在什么样的位置。
  2. train是bool变量,为true表示创建的是训练集,为false表示是测试集。
  3. transform表示想对数据集进行什么样的变换。
  4. target_transform是对target进行transform。
  5. download为true表示从网上自动下载数据集,为false则不会下载。

首先导入torchvision

import torchvison

调用datasets工具包,对CIFAR10数据集进行下载

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

我们可以对下载得到的数据集进行分析

print(test_set[0])

输出:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1C58D023F40>, 3)

第一部分为PIL的图片数据,第二部分的3代表一个target(类别)。

print(test_set.classes)

输出:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

说明3代表的是猫这个类别。

知道了这些后,我们就可以用两个变量来获取test_set的数据

img, target = test_set[0]
print(img)
print(target)

输出:
<PIL.Image.Image image mode=RGB size=32x32 at 0x1C583EE91C0>
3

CIFAR10包含了60000张32×32分为10个类别的彩色图片,其中50000张是训练图像,10000张是测试图像。

利用Transforms进行类型转换

先实例化transform对象

dataset_transform = torchvision.transform.Compose([torchvision.transforms.ToTensor()])

然后可以在数据集的参数中加入transform参数

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)  
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

输出图片

print(test_set[0])

输出:tensor数据类型

和Tensorboard结合

writer = SummaryWriter("p10")
for i in range(10):
	img, target = test_set[i]
	writer.add_image("test_set", img, i)
writer.close()

打开tensorboard,可以发现生成了step1-10的图片。

在这里插入图片描述

其他数据集的参数也类似,可以通过官方文档查看数据集参数,但有些数据集下载的很慢,我们可以先设置download为True,然后打开下载地址用迅雷(或其他)下载速度快的软件下载,下载完成后保存到对应文件夹中,这时再运行代码,它就会调用已经下载好的压缩包进行解压,减少下载所需时间。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值