pytorch入门之CIFAR10数据集

本文介绍了CIFAR10数据集在计算机视觉任务中的作用,如何使用torchvision库加载并预处理数据,以及如何使用DataLoader和TensorBoard进行可视化。教程详细展示了从下载数据到训练过程的实践步骤。
摘要由CSDN通过智能技术生成

概念

CIFAR10数据集
CIFAR-10(Canadian Institute for Advanced Research-10)是一个常用的计算机视觉数据集,用于图像分类任务。它由60000个32x32彩色图像组成,这些图像均来自于10个不同的类别,每个类别包含6000个图像。数据集被分为两个部分:训练集和测试集,其中训练集包含50000个图像,测试集包含10000个图像。

CIFAR-10数据集中的图像涵盖了广泛的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。每个图像都有一个标签,表示它所属的类别。这个数据集被广泛用于计算机视觉领域的算法开发、模型训练和性能评估。

由于图像尺寸较小且类别数较少,CIFAR-10数据集通常用于快速验证和原型开发,以及用于学习和理解各种计算机视觉算法的基本原理。它也被用作深度学习模型的基准数据集,用于评估模型在图像分类任务上的性能和泛化能力。

实战

用torchvision加载CIFAR-10数据集,用transforms转换成我们所需要的数据格式,并通过dataloader加载,tensorboard进行展示

1、torchvision加载

# root="./data"表示将数据加载到你当前工作目录的data下,train=True表示所下载的数据集为测试数据集,download=True表示如果不存在则自动下载
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)

test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)

下载过程中会出现下载进度
如果已下载则会出现
在这里插入图片描述
展示

print(test_set[0])

输出test_set[0]第一张图片的格式以及它对应的类别
在这里插入图片描述

print(train_set.classes)

输出train_set所包含的所有类别
在这里插入图片描述

img, target = test_set[0]
img.show()
# 输出第一张图片对应的target
print(target)
# 输出这张target对应的具体类别是什么
print(train_set.classes[target])

img.show()
在这里插入图片描述

2、transforms转换

torchvision.datasets.CIFAR10中的参数有很多,可以加上transforms将下载的数据集进行转换供我们后续训练
在这里插入图片描述
transform=torchvision.transforms.ToTensor将数据集转换为Tensor格式,此时的img就无法show()了

train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())

test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=torchvision.transforms.ToTensor())

在这里插入图片描述

3、dataloader加载

# dataset=train_set表示要加载的数据集,batch_size=4表示一次加载多少张图片,shuffle=True是否将图片的顺序打乱,num_workers=0一共有多少线程进行加载,drop_last=False最后一批数据不够定义的batch_size的话则删除
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=test_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
for data in train_loader:
    imgs, target = data
    print(imgs.shape)
    print(target)

在这里插入图片描述
torch.Size中的4表示一批一共4张图片,3表示图片为3通道,32表示为32*32大小
tensor([0, 5, 3, 7])表示这4张图片的类别

4、tensorboard展示

# 读取目录
writer = SummaryWriter("dataloader")

step = 0
for data in train_loader:
    imgs, target = data
    # print(imgs.shape)
    # print(target)
    writer.add_images("test_dataloader", imgs, step)
    step = step + 1

writer.close()

启动tensorboard查看
tensorboard --logdir=dataloader目录 --port=6007
在这里插入图片描述

总结

大概流程就是加载然后通过transofrms进行转换,转换为tensor格式后加载,然后用自己写的网络进行训练

由于目前正在学习cv这一块故整理笔记

参考:
https://www.bilibili.com/video/BV1hE411t7RN?p=15&vd_source=6da4443c90ff512f68d386f6607eadab

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值