代码块1
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l
运行后会报错:无tqdm,运行下边命令
pip install tqdm
然后会报错:无torchtext,运行下边命令
pip install torchtext -i https://pypi.tuna.tsinghua.edu.cn/simple
注:采用国内镜像源,不会更新torch版本
代码块2
# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
运行后会报错:function object no attribute ‘set_matplotlib_formats’
原因是没有在代码块1中导入相应的包,在代码块1中加入以下代码
from IPython import display
然后就可以顺利运行。