torchvision中数据集的使用

介绍

进入pytorch官网,首页Docs下有不同的库
在这里插入图片描述进入torchvision库,可以查看开源数据集
在这里插入图片描述

CIFAR10数据集

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

# 使用参数
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

下载数据集

# 下载数据集并保存在dataset文件夹中
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

如果下载较慢,可以复制下载链接用迅雷下载
在这里插入图片描述下载好将数据集压缩包放入"./dataset"中,直接运行当前Python文件,数据集可以直接解压
如果没有显示下载链接,ctrl+单击当前数据集名称,查看数据集源代码也可以找到数据集的下载链接
在这里插入图片描述

打印数据集的第一个元素

import torchvision

# 下载数据集并保存在dataset文件夹中
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

print(test_set[0]) # 打印数据集中的第一个元素
print(test_set.classes) #打印数据集的classes属性

在这里插入图片描述<PIL.Image.Image image mode=RGB size=32x32 at 0x2AB0D45F2B0>是图片信息,3是classes属性的’cat’类

打开图片

import torchvision

# 下载数据集并保存在dataset文件夹中
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

print(test_set[0]) # 打印数据集中的第一个元素
print(test_set.classes) #打印数据集的classes属性

img, target = test_set[0] # 将数据集中的信息保存在img, 和target中
print(img)
print(target)
print(test_set.classes[target])
img.show() # 打开图片

在这里插入图片描述

在这里插入图片描述

与transform结合使用

import torchvision
from torch.utils.tensorboard import SummaryWriter

# 将PIL图片类型转换为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
# 下载数据集并将数据集中的每一张图片都转换为Tensor数据类型
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值