本文仅作为个人学习笔记用,如有错误请指正,欢迎大家讨论学习。本博客内容来自动手学深度学习
一、图像分类数据集
导包
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
from IPython import display
数据集下载
创建子目录下载训练集以及测试集数据
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
标签管理
接受label
参数后,使用列表推导式来生成并返回一个新的列表。对于输入列表 labels
中的每个元素 i
,它首先将 i
转换为整数(在大多数情况下,labels
中的元素可能已经是整数类型,这一步可能是为了确保兼容性或处理特殊情况),然后使用这个整数作为索引从 text_labels
列表中取出对应的文本标签。最终,这个表达式生成一个包含所有对应文本标签的新列表,并将其返回。
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
样本可视化
figsize
由接收的 num_rows
, num_col
s决定,scale
为设置好的参数;
创建一个子图网格axes
,行列由num_rows
, num_cols
,其大小由figsize
确定;
axes.flatten()
将二维的子图数组转换为一维,以便后续遍历
接着,函数遍历每个子图和对应的图像,根据图像的类型,使用ax.imshow()
方法显示图像。
对于每个子图,还关闭了x轴和y轴的显示,并可选地设置了标题。
最后,函数返回包含所有子图的数组。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
加载图像
从mnist_train
数据加载器中获取下一批数据,其中batch_size=18
意味着每次加载18个样本。X是图像数据,y是对应的标签。
在调用show_images
函数之前,先对X进行了重塑(reshape),因为从数据加载器获取的X可能是四维的(批量大小、通道数、高度、宽度),而Fashion-MNIST是灰度图像,通道数为1,所以这里将其重塑为三维(18个28x28的图像)。然后,指定了2行9列的布局来显示这些图像,并使用get_fashion_mnist_labels
函数将y中的数字标签转换为文本标签作为图像的标题。
最后,调用d2l.plt.show()
(“动手学深度学习”的库封装matplotlib的功能)来显示图像和标题。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()
完整代码及运行结果
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()