[Pytorch系列-33]:数据集 - torchvision与CIFAR10/CIFAR100详解

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 CIFAR10数据集

2.1 数据集概述

2.2 与 MNIST 数据集比较

2.3 下载地址

第3章 TorchVision对CIFAR10的支持

3.1 函数原型

3.2 数据下载前的准备

3.3 数据集下载与导入

3.4 显示单张样本图片

3.5 启动loader对象

3.6 显示批量图片

第4章 CIFAR100与CIFAR10的比较

4.1 相同点

4.2 不同点

第5章 图片集的手工下载

5.1 CIFAR10

5.2 CIFAR100



第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageNet flowers
  • Imagenet-12
  • STL10

第2章 CIFAR10数据集

2.1 数据集概述

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。

该数据集共有60000张彩色图像,这些图像是32*32,分为10个类RGB 彩色三通道图 片,每类6000张图。

其中,50000张用于训练,构成了5个训练批次,每一批10000张图;

其中,10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批次。

注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

CIFAR-10 的图片样例如图所示,包括

飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
å¨è¿éæå¥å¾çæè¿°

2.2 与 MNIST 数据集比较

与 MNIST 数据集比较, CIFAR-10 具有以下不同点:

  • CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
  • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
  • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
  • 直接的全连接的线性模型,即使在MNIST表现良好,在 CIFAR-10数据集上表现得很差。

2.3 下载地址

官方下载地址:(很慢)

一共有三个版本:python,matlab,binary version 适用于C语言

http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

第3章 TorchVision对CIFAR10的支持

3.1 函数原型

CIFAR10 (root, train=True, transform=None, target_transform=None, download=False)

  • root:存储数据集的根目录
  • train=True or false:训练集还是测试集
  • transform=None:在加载数据前的格式转换
  • target_transform=None:
  • download=False:是否需要在线下载

3.2 数据下载前的准备

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库

import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils 
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())

3.3 数据集下载与导入

如果本地没有数据集,会自动远程下载

#2-1 准备数据集
train_data = dataset.CIFAR10 (root = "cifar10",
                           train = True,
                           transform = transforms.ToTensor(),
                           download = True)

#2-1 准备数据集
test_data = dataset.MNIST(root = "cifar10",
                           train = False,
                           transform = transforms.ToTensor(),
                           download = True)

print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
100.0%
Extracting cifar10\cifar-10-python.tar.gz to cifar10
1.1%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-images-idx3-ubyte.gz to cifar10\MNIST\raw\train-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\train-images-idx3-ubyte.gz to cifar10\MNIST\raw
102.8%
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-labels-idx1-ubyte.gz to cifar10\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting cifar10\MNIST\raw\train-labels-idx1-ubyte.gz to cifar10\MNIST\raw
5.0%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz
112.7%
Extracting cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw
Processing...
C:\ProgramData\Anaconda3\envs\pytorch1.8_py3.8\lib\site-packages\torchvision\datasets\mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Done!
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: cifar10
    Split: Train
    StandardTransform
Transform: ToTensor()
size= 50000

Dataset MNIST
    Number of datapoints: 10000
    Root location: cifar10
    Split: Test
    StandardTransform
Transform: ToTensor()
size= 10000

3.4 显示单张样本图片

#原图不叠加噪声
#获取一张图片数据
print("原始Pytorch图片")
image, label = train_data[2]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n通道转换后的Numpy图片")
image = image.numpy().transpose(1,2,0)  #交换维度,从GBR换成RGB
print("numpy image shape:", image.shape)
print("numpy image label:", label)

plt.imshow(image)
plt.show()

3.5 启动loader对象

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
                                  batch_size = 8,
                                  shuffle = True)

test_loader = data_utils.DataLoader(dataset = test_data,
                                  batch_size = 8,
                                  shuffle = True)

print(train_loader)
print(test_loader)
print(len(train_data), len(train_data)/8)
print(len(test_data),  len(test_data)/8)
<torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA85E0>
<torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA8F40>
50000 6250.0
10000 1250.0

3.6 显示批量图片

pytorch对图片的格式定义与Numpy对图片的格式定义是不一样的。

因此需要通过transpose()进行维度的变换。

#显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])

print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs, nrow = 4)
print(images.shape)
print(labels.shape)

print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0) 
print(images.shape)
print(labels.shape)

print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([8, 3, 32, 32])
torch.Size([8])
8

合并成一张三通道灰度图片
torch.Size([3, 70, 138])
torch.Size([8])

转换成imshow格式
(70, 138, 3)
torch.Size([8])

显示图片

第4章 CIFAR100与CIFAR10的比较

4.1 相同点

采用相同的图片布局:3 * 32 * 32 = 3072

4.2 不同点

  • 有100个类,每个类包含600个图像。
  • 每类各有500个训练图像和100个测试图像。
  • CIFAR-100中的100个类被分成20个超类。
  • 每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类

以下是CIFAR-100中的20个超类别以及对应的子类:

超类类别
水生哺乳动物海狸,海豚,水獭,海豹,鲸鱼
水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼
花卉兰花,罂粟花,玫瑰,向日葵,郁金香
食品容器瓶子,碗,罐子,杯子,盘子
水果和蔬菜苹果,蘑菇,橘子,梨,甜椒
家用电器时钟,电脑键盘,台灯,电话机,电视机
家用家具床,椅子,沙发,桌子,衣柜
昆虫蜜蜂,甲虫,蝴蝶,毛虫,蟑螂
大型食肉动物熊,豹,狮子,老虎,狼
大型人造户外用品桥,城堡,房子,路,摩天大楼
大自然的户外场景云,森林,山,平原,海
大杂食动物和食草动物骆驼,牛,黑猩猩,大象,袋鼠
中型哺乳动物狐狸,豪猪,负鼠,浣熊,臭鼬
非昆虫无脊椎动物螃蟹,龙虾,蜗牛,蜘蛛,蠕虫
宝贝,男孩,女孩,男人,女人
爬行动物鳄鱼,恐龙,蜥蜴,蛇,乌龟
小型哺乳动物仓鼠,老鼠,兔子,母老虎,松鼠
树木枫树,橡树,棕榈,松树,柳树
车辆1自行车,公共汽车,摩托车,皮卡车,火车
车辆2割草机,火箭,有轨电车,坦克,拖拉机

第5章 图片集的手工下载

5.1 CIFAR10

CIFAR-10 python版本
CIFAR-10 Matlab版本
CIFAR-10二进制版本(适用于C程序)

5.2 CIFAR100

CIFAR-100 python版本
CIFAR-100 Matlab版本
CIFAR-100二进制版本(适用于C程序)


作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970

  • 4
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
首先,让我们先下载和导入必要的PyTorchtorchvision库: ```python import torch import torchvision import torchvision.transforms as transforms ``` 接下来,我们可以定义一些数据转换,以便将CIFAR10图像的像素值转换为张量,并对它们进行标准化。我们还可以将数据集分成训练集和测试集。 ```python transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) ``` 现在,我们可以显示一些图像来检查它们是否已成功加载。我们将使用matplotlib库来绘制图像。 ```python import matplotlib.pyplot as plt import numpy as np # 定义类别标签 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 随机获取一些训练图像 dataiter = iter(trainloader) images, labels = dataiter.next() # 绘制图像 def imshow(img): img = img / 2 + 0.5 # 反归一化 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() # 显示图像 imshow(torchvision.utils.make_grid(images)) # 输出标签 print(' '.join('%5s' % classes[labels[j]] for j in range(4))) ``` 这将显示四张训练图片和它们的标签。现在,我们已经成功地加载并显示了CIFAR10数据集,可以开始使用PyTorch进行图像分类任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文火冰糖的硅基工坊

你的鼓励是我前进的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值