概念
CIFAR10数据集
CIFAR-10(Canadian Institute for Advanced Research-10)是一个常用的计算机视觉数据集,用于图像分类任务。它由60000个32x32彩色图像组成,这些图像均来自于10个不同的类别,每个类别包含6000个图像。数据集被分为两个部分:训练集和测试集,其中训练集包含50000个图像,测试集包含10000个图像。
CIFAR-10数据集中的图像涵盖了广泛的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。每个图像都有一个标签,表示它所属的类别。这个数据集被广泛用于计算机视觉领域的算法开发、模型训练和性能评估。
由于图像尺寸较小且类别数较少,CIFAR-10数据集通常用于快速验证和原型开发,以及用于学习和理解各种计算机视觉算法的基本原理。它也被用作深度学习模型的基准数据集,用于评估模型在图像分类任务上的性能和泛化能力。
实战
用torchvision加载CIFAR-10数据集,用transforms转换成我们所需要的数据格式,并通过dataloader加载,tensorboard进行展示
1、torchvision加载
# root="./data"表示将数据加载到你当前工作目录的data下,train=True表示所下载的数据集为测试数据集,download=True表示如果不存在则自动下载
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)
下载过程中会出现下载进度
如果已下载则会出现
展示
print(test_set[0])
输出test_set[0]第一张图片的格式以及它对应的类别
print(train_set.classes)
输出train_set所包含的所有类别
img, target = test_set[0]
img.show()
# 输出第一张图片对应的target
print(target)
# 输出这张target对应的具体类别是什么
print(train_set.classes[target])
2、transforms转换
torchvision.datasets.CIFAR10中的参数有很多,可以加上transforms将下载的数据集进行转换供我们后续训练
transform=torchvision.transforms.ToTensor将数据集转换为Tensor格式,此时的img就无法show()了
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=torchvision.transforms.ToTensor())
3、dataloader加载
# dataset=train_set表示要加载的数据集,batch_size=4表示一次加载多少张图片,shuffle=True是否将图片的顺序打乱,num_workers=0一共有多少线程进行加载,drop_last=False最后一批数据不够定义的batch_size的话则删除
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=test_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
for data in train_loader:
imgs, target = data
print(imgs.shape)
print(target)
torch.Size中的4表示一批一共4张图片,3表示图片为3通道,32表示为32*32大小
tensor([0, 5, 3, 7])表示这4张图片的类别
4、tensorboard展示
# 读取目录
writer = SummaryWriter("dataloader")
step = 0
for data in train_loader:
imgs, target = data
# print(imgs.shape)
# print(target)
writer.add_images("test_dataloader", imgs, step)
step = step + 1
writer.close()
启动tensorboard查看
tensorboard --logdir=dataloader目录 --port=6007
总结
大概流程就是加载然后通过transofrms进行转换,转换为tensor格式后加载,然后用自己写的网络进行训练
由于目前正在学习cv这一块故整理笔记
参考:
https://www.bilibili.com/video/BV1hE411t7RN?p=15&vd_source=6da4443c90ff512f68d386f6607eadab