Solution first (a good habit):
Change the line that raises the error
plt.title(labels_map[label])
to
plt.title(labels_map[label.item()])
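Depending on the torchvision version (or on any target_transform you pass), label may be a 0-d tensor or already a plain Python int. In the int case, label.item() fails with AttributeError, while int(label) works for both. The following is a minimal sketch under that assumption. The helper name title_for is hypothetical and not part of the tutorial:

```python
import torch

labels_map = {
    0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
    5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot",
}

def title_for(label):
    """Hypothetical helper: look up a class name whether label is an int or a 0-d tensor."""
    # int() accepts both Python ints and single-element tensors,
    # whereas .item() only exists on tensors (and NumPy scalars).
    return labels_map[int(label)]

print(title_for(torch.tensor(5)))  # Sandal
print(title_for(5))                # Sandal
```

Inside the plotting loop this would read plt.title(title_for(label)).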
Problem and error analysis:
I recently copied some code from the official PyTorch documentation to experiment with:
# This code is from Tutorials > Datasets & DataLoaders
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
and it kept raising this error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-35-336d271805e0> in <module>
38 img, label = training_data[sample_idx]
39 figure.add_subplot(rows, cols, i)
---> 40 plt.title(labels_map[label])
41 plt.axis("off")
42 plt.imshow(img.squeeze(), cmap="gray")
KeyError: tensor(5)
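The cause: labels_map is a dict (not an array) whose keys are Python ints, but here label is a 0-d tensor such as tensor(5). A tensor does not hash like the integer it holds, so the lookup finds no matching key and raises KeyError. label.item() returns the plain Python int, which does match a key.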
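Here is a small self-contained reproduction that does not download FashionMNIST. The hand-built tensor label is an assumption that stands in for what the dataset returned in the traceback above:

```python
import torch

labels_map = {5: "Sandal"}
label = torch.tensor(5)  # assumed stand-in for the label the dataset returned

# A dict matches keys by hash first; torch.Tensor hashes by object identity,
# not by value, so the int key 5 is never found.
print(label in labels_map)           # False

# .item() converts a one-element tensor to a Python number (int for integer dtypes).
key = label.item()
print(type(key), key in labels_map)  # <class 'int'> True
print(labels_map[key])               # Sandal

try:
    labels_map[label]                # the same lookup as in the failing loop
    raised = False
except KeyError:
    raised = True
print(raised)                        # True
```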