作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970
目录
第1章 TorchVision概述
1.1 TorchVision
Pytorch非常有用的工具集:
- torchtext:处理自然语言
- torchaudio:处理音频的
- torchvision:处理图像视频的。
torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。
1.2 TorchVision的安装
pip install torchvision
1.3 TorchVision官网的数据集
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
1.4 TorchVision常见的数据集概述
- MNIST
- CIFAR10
- CIFAR100
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageNet flowers
- Imagenet-12
- STL10
第2章 CIFAR10数据集
2.1 数据集概述
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。
该数据集共有60000张彩色图像,这些图像是32*32,分为10个类RGB 彩色三通道图 片,每类6000张图。
其中,50000张用于训练,构成了5个训练批次,每一批10000张图;
其中,10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批次。
注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
CIFAR-10 的图片样例如图所示,包括
飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
2.2 与 MNIST 数据集比较
与 MNIST 数据集比较, CIFAR-10 具有以下不同点:
- CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
- CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
- 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
- 直接的全连接的线性模型,即使在MNIST表现良好,在 CIFAR-10数据集上表现得很差。
2.3 下载地址
官方下载地址:(很慢)
一共有三个版本:python,matlab,binary version 适用于C语言
http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
第3章 TorchVision对CIFAR10的支持
3.1 函数原型
CIFAR10 (root, train=True, transform=None, target_transform=None, download=False)
- root:存储数据集的根目录
- train=True or false:训练集还是测试集
- transform=None:在加载数据前的格式转换
- target_transform=None:
- download=False:是否需要在线下载
3.2 数据下载前的准备
#环境准备
import numpy as np # numpy数组库
import math # 数学运算库
import matplotlib.pyplot as plt # 画图库
import torch # torch基础库
import torchvision.datasets as dataset #公开数据集的下载和管理
import torchvision.transforms as transforms #公开数据集的预处理库,格式转换
import torchvision.utils as utils
import torch.utils.data as data_utils #对数据集进行分批加载的工具集
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
3.3 数据集下载与导入
如果本地没有数据集,会自动远程下载
#2-1 准备数据集
train_data = dataset.CIFAR10 (root = "cifar10",
train = True,
transform = transforms.ToTensor(),
download = True)
#2-1 准备数据集
test_data = dataset.MNIST(root = "cifar10",
train = False,
transform = transforms.ToTensor(),
download = True)
print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
100.0%
Extracting cifar10\cifar-10-python.tar.gz to cifar10
1.1%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-images-idx3-ubyte.gz to cifar10\MNIST\raw\train-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\train-images-idx3-ubyte.gz to cifar10\MNIST\raw
102.8%
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-labels-idx1-ubyte.gz to cifar10\MNIST\raw\train-labels-idx1-ubyte.gz Extracting cifar10\MNIST\raw\train-labels-idx1-ubyte.gz to cifar10\MNIST\raw
5.0%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz
112.7%
Extracting cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw Processing...
C:\ProgramData\Anaconda3\envs\pytorch1.8_py3.8\lib\site-packages\torchvision\datasets\mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:143.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Done! Dataset CIFAR10 Number of datapoints: 50000 Root location: cifar10 Split: Train StandardTransform Transform: ToTensor() size= 50000 Dataset MNIST Number of datapoints: 10000 Root location: cifar10 Split: Test StandardTransform Transform: ToTensor() size= 10000
3.4 显示单张样本图片
#原图不叠加噪声
#获取一张图片数据
print("原始Pytorch图片")
image, label = train_data[2]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\n通道转换后的Numpy图片")
image = image.numpy().transpose(1,2,0) #交换维度,从GBR换成RGB
print("numpy image shape:", image.shape)
print("numpy image label:", label)
plt.imshow(image)
plt.show()
3.5 启动loader对象
# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
batch_size = 8,
shuffle = True)
test_loader = data_utils.DataLoader(dataset = test_data,
batch_size = 8,
shuffle = True)
print(train_loader)
print(test_loader)
print(len(train_data), len(train_data)/8)
print(len(test_data), len(test_data)/8)
<torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA85E0> <torch.utils.data.dataloader.DataLoader object at 0x0000012C3BBA8F40> 50000 6250.0 10000 1250.0
3.6 显示批量图片
pytorch对图片的格式定义与Numpy对图片的格式定义是不一样的。
因此需要通过transpose()进行维度的变换。
#显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])
print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs, nrow = 4)
print(images.shape)
print(labels.shape)
print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0)
print(images.shape)
print(labels.shape)
print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片 torch.Size([8, 3, 32, 32]) torch.Size([8]) 8 合并成一张三通道灰度图片 torch.Size([3, 70, 138]) torch.Size([8]) 转换成imshow格式 (70, 138, 3) torch.Size([8]) 显示图片
第4章 CIFAR100与CIFAR10的比较
4.1 相同点
采用相同的图片布局:3 * 32 * 32 = 3072
4.2 不同点
- 有100个类,每个类包含600个图像。
- 每类各有500个训练图像和100个测试图像。
- CIFAR-100中的100个类被分成20个超类。
- 每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)
以下是CIFAR-100中的20个超类别以及对应的子类:
超类 | 类别 |
---|---|
水生哺乳动物 | 海狸,海豚,水獭,海豹,鲸鱼 |
鱼 | 水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼 |
花卉 | 兰花,罂粟花,玫瑰,向日葵,郁金香 |
食品容器 | 瓶子,碗,罐子,杯子,盘子 |
水果和蔬菜 | 苹果,蘑菇,橘子,梨,甜椒 |
家用电器 | 时钟,电脑键盘,台灯,电话机,电视机 |
家用家具 | 床,椅子,沙发,桌子,衣柜 |
昆虫 | 蜜蜂,甲虫,蝴蝶,毛虫,蟑螂 |
大型食肉动物 | 熊,豹,狮子,老虎,狼 |
大型人造户外用品 | 桥,城堡,房子,路,摩天大楼 |
大自然的户外场景 | 云,森林,山,平原,海 |
大杂食动物和食草动物 | 骆驼,牛,黑猩猩,大象,袋鼠 |
中型哺乳动物 | 狐狸,豪猪,负鼠,浣熊,臭鼬 |
非昆虫无脊椎动物 | 螃蟹,龙虾,蜗牛,蜘蛛,蠕虫 |
人 | 宝贝,男孩,女孩,男人,女人 |
爬行动物 | 鳄鱼,恐龙,蜥蜴,蛇,乌龟 |
小型哺乳动物 | 仓鼠,老鼠,兔子,母老虎,松鼠 |
树木 | 枫树,橡树,棕榈,松树,柳树 |
车辆1 | 自行车,公共汽车,摩托车,皮卡车,火车 |
车辆2 | 割草机,火箭,有轨电车,坦克,拖拉机 |
第5章 图片集的手工下载
5.1 CIFAR10
CIFAR-10 python版本
CIFAR-10 Matlab版本
CIFAR-10二进制版本(适用于C程序)
5.2 CIFAR100
CIFAR-100 python版本
CIFAR-100 Matlab版本
CIFAR-100二进制版本(适用于C程序)
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970