# Two helper functions for visualizing the dataset.
def get_fashion_mnist_labels(labels):
    """Return the Fashion-MNIST text labels for numeric class indices.

    Args:
        labels: An iterable of class indices in ``range(10)``. Each element
            only needs to support ``int()`` (e.g. plain ints, or the scalar
            elements obtained by iterating a 1-D label tensor).

    Returns:
        A list of text labels, one per element of ``labels``.
    """
    # The comma after 'shirt' is essential: without it Python implicitly
    # concatenates 'shirt' and 'sneaker' into 'shirtsneaker', leaving only
    # 9 entries, so class index 9 ('ankle boot') raised IndexError.
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images in a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: Sequence of images. Torch tensors are converted to NumPy arrays
            before plotting; anything else (e.g. PIL images) is passed to
            ``imshow`` unchanged.
        num_rows: Number of subplot rows.
        num_cols: Number of subplot columns.
        titles: Optional sequence of per-image titles, indexed in the same
            order as ``imgs``.
        scale: Per-cell size multiplier; the figure size is
            ``(num_cols * scale, num_rows * scale)``.

    Returns:
        A flat (1-D) array of the grid's Axes objects.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False makes subplots always return a 2-D array of Axes, so
    # flatten() also works for a 1 x 1 grid (where a bare Axes object, which
    # has no flatten method, would otherwise be returned).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    # zip stops at the shorter of axes / imgs, so surplus grid cells are
    # simply left unused.
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Image tensor: detach from autograd and move to CPU first,
            # because .numpy() fails on tensors that require grad or that
            # live on a GPU.
            ax.imshow(img.detach().cpu().numpy())
        else:
            # PIL image (or anything else imshow accepts directly).
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
# Images of a few samples and their corresponding labels: take the first
# 18 training examples and draw them in a 2 x 9 grid titled by class name.
X, y = next(iter(data.DataLoader(mnist_train,
                                 batch_size=18)))
show_images(X.reshape((18, 28, 28)), num_rows=2, num_cols=9,
            titles=get_fashion_mnist_labels(y));
IndexError Traceback (most recent call last) <ipython-input-31-062a59775dcb> in <module> 1 #几个样本的图像及其相对应的标签 2 X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) ----> 3 show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y)); 4 '''X,y = [],[] 5 #初始化两个列表 <ipython-input-27-9703a7aece05> in get_fashion_mnist_labels(labels) 4 text_labels = [ 't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal','shirt' 5 'sneaker', 'bag', 'ankle boot'] ----> 6 return [text_labels[int(i)] for i in labels] 7 8 def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): <ipython-input-27-9703a7aece05> in <listcomp>(.0) 4 text_labels = [ 't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal','shirt' 5 'sneaker', 'bag', 'ankle boot'] ----> 6 return [text_labels[int(i)] for i in labels] 7 8 def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): IndexError: list index out of range