介绍
进入pytorch官网,首页Docs下有不同的库
进入torchvision库,可以查看开源数据集
CIFAR10数据集
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
# 使用参数
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
下载数据集
# 下载数据集并保存在dataset文件夹中
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
如果下载较慢,可以复制下载链接用迅雷下载
下载好将数据集压缩包放入"./dataset"中,直接运行当前Python文件,数据集可以直接解压
如果没有显示下载链接,ctrl+单击当前数据集名称,查看数据集源代码也可以找到数据集的下载链接
打印数据集的第一个元素
import torchvision
# 下载数据集并保存在dataset文件夹中
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0]) # 打印数据集中的第一个元素
print(test_set.classes) #打印数据集的classes属性
<PIL.Image.Image image mode=RGB size=32x32 at 0x2AB0D45F2B0>是图片信息,3是classes属性的’cat’类
打开图片
import torchvision
# 下载数据集并保存在dataset文件夹中
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0]) # 打印数据集中的第一个元素
print(test_set.classes) #打印数据集的classes属性
img, target = test_set[0] # 将数据集中的信息保存在img, 和target中
print(img)
print(target)
print(test_set.classes[target])
img.show() # 打开图片
与transform结合使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 将PIL图片类型转换为tensor数据类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 下载数据集并将数据集中的每一张图片都转换为Tensor数据类型
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()