代码如下
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
import sys
#使用svg来显示图片, 使得清晰度变高
d2l.use_svg_display()
trans = transforms.ToTensor()#通过ToTensor将图像数据从PIL类型变换成Pytorch的tensor格式, 32位浮点数格式
#train=true代表的是下载的训练数据集, transform=trans代表需要得到的是pytorch的tensor格式并非图片, download代表默认从网络下载
mnist_train = torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=False)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=False)
def get_fashion_mnist_labels(labels):
#一个列表
text_labels = ["t-shirt","trouser","pullover","dress","coat","sandal","shirt","sneaker","bag","ankle boot"]
#将y的值与标签列表一对应, 最后返回字符串标签
return [text_labels[int(i)] for i in labels]
#不清楚实现, 反正一顿操作就能显示数据
def show_images(imgs,num_rows,num_cols,titles,scale=1.5):
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
#最后得到一个数组
for i, (ax, img) in enumerate(zip(axes, imgs)):
ax.imshow(img.numpy())
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
#data.DataLoader:打乱, 获取定长的数据, 此时没有打乱, 加shuffle打乱, 返回x和y的值
x,y = next(iter(data.DataLoader(mnist_train,batch_size=18)))
#x是用二进制表示的照片数据,y是最终结果
#reshape(18,28,28)代表18张28×28的图片
show_images(x.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))
d2l.plt.show()
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,#num_worker代表使用CPU的几个核
num_workers=get_dataloader_workers())
#定义timer来获取读取数据的速度
timer = d2l.Timer()
timer.start()
#将数据集的内容分配到x和y中
for x, y in train_iter:
continue
timer.stop()
运行效果如图