3.5 图像分类数据集
这章主要是讲对数据的一些处理和准备
数据集:使用类似但更复杂的Fashion-MNIST数据集
%matplotlib inline
import torch
# torchvision:pytorch视觉实现的一个库
import torchvision
#from torchvision import datasets
from torch.utils import data
# transforms对数据进行操作的模组
from torchvision import transforms
from d2l import torch as d2l
# 这个函数也是画图的函数
d2l.use_svg_display()
① 将数据从torch库中下载下来,并转变成pytorch的数据类型torchvision.datasets
torchvision
:https://blog.csdn.net/wohu1104/article/details/107743290
知识点1; 我们使用是配套的
- 如果导入的是
import torchvision
,那么使用你面的函数就是torchvision.datasets.XXX; - 如果导入的是
from torchvision import datasets
,那么使用你面的函数就是datasets.XXX;
同理torchvision下的主要函数transform
# 通过ToTensor实例将图像数据从PIL类型变换成float32格式,
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
# 准备训练集 和测试机
# 训练集下载的地方;是否是训练集;下载时候转变成32位浮点数格式;是否下载
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
- 【观察数据】
Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。
# 看看训练集 测试集长度
len(mnist_train), len(mnist_test)
(60000, 10000)
个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1
。 为了简洁起见,本书将高度h像素、宽度w像素图像的形状记为或(h,w)
mnist_train[0][0].shape
torch.Size([1, 28, 28])
- mnist_train[0][0]第一个是取第几张照片,范围是0~59999;60000会报错,因为一共只有60000张图
- 第二个取值
0
:图片数据;1
:标签数据
下面是第1张图属于第9
类;第60000张图属于第5
类
mnist_train[0][1]
9
mnist_train[59999][1]
5
② 处理分类,用get_fashion_mnist_labels
将分类数字0-9
与具体名称一一对应
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。 以下函数用于在数字标签索引及其文本名称之间进行转换。
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
③ 名称有了,用show_images
来呈现图像
参数:图片 展示成几行几列 标题默认没有 规模(尺寸大小)
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
# 设置图片尺寸:scale就是每张图的大小(基数),这里是1.5;改大图像就变大
figsize = (num_cols * scale, num_rows * scale)
# 这里的 _ 表示忽略不使用的变量、即fig;
# d2l.plt.subplots()把多张图拼成一张,其中figsize把上面尺寸传下来
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
#把一张图的数据拉直(成):在用plt.subplots画多个子图中,ax = ax.flatten()将ax由n*m的Axes组展平成1*nm的Axes组
# 这一步就是 不用在二维数组来定位第几个图了 直接第几个图,会自动落位?
axes = axes.flatten()
#print("axes",axes[1][0])
for i, (ax, img) in enumerate(zip(axes, imgs)):
# i:第几个图
# imgs:是一个图的数据,是个28*28的二维数组
# 判断传入的图片是否为张量
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
# 设取消横纵坐标上的刻度(横、纵轴均为28)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
-
把多张图拼成一张图的函数:
fig, ax = plt.subplots()
:https://www.cnblogs.com/komean/p/10670619.htmlfig, ax = plt.subplots(1,3)
: 其中参数1和3分别代表子图的行数和列数,一共有 1x3 个子图像。函数返回一个figure图像和子图ax的array列表。fig, ax = plt.subplots(1,3,1)
: 最后一个参数1代表第一个子图。
如果想要设置子图的宽度和高度可以在函数内加入figsize值fig, ax = plt.subplots(1,3,figsize=(15,7))
: 这样就会有1行3个15x7大小的子图。【本文用法】 -
axes = axes.flatten()
:https://blog.csdn.net/weixin_38314865/article/details/84785141 -
zip()
:https://blog.csdn.net/lanmy_dl/article/details/124216431 -
enumerate()
遍历: https://www.runoob.com/python/python-func-enumerate.html -
结合使用;https://blog.csdn.net/weixin_43408110/article/details/87731547
④ 来设计上一步的X:imags
和y:titles
# 取出X y 用next(iter)搞了第一组数据,数据一组18个
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
#我们一个`X:torch.Size([18, 1, 28, 28])`传进来18张照片 每张[1*28*28]
X.shape
torch.Size([18, 1, 28, 28])
# 不要色彩通道了 直接这组张数 和 每张照片的尺寸;一排9张 一o共2排
show_images(X.reshape(18,28, 28), 2,9, titles=get_fashion_mnist_labels(y));
3.5.2 读取小批量
# 准备训练数据
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
# batch_size一组256个,shuffle随机取,几个进程做
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
# 测试了一下训练的时间
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{
timer.stop():.2f} sec'
'5.81 sec'
3.5.3. 整合所有组件
# resize 就是我们输入照片是28*28,如果想把尺寸变大 用到resize
def load_data_fashion_mnist(batch_size, resize=None): #@save