RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
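- While running the PyTorch MNIST handwritten-digit recognition example, I ran into a shape mismatch error. The traceback points to:
```python
images, labels = next(iter(data_loader_train))
```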
![](https://i-blog.csdnimg.cn/blog_migrate/f5ff6c618eaeea3821ce3daba37e2d0b.png)
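- After several attempts, I found that the error is not actually caused by this line. The dataset applies `transform` lazily, only when the `DataLoader` fetches a batch, so the failure merely surfaces at `next(iter(...))`. The real culprit is `transforms.Normalize`: MNIST images are grayscale with a single channel, but `mean` and `std` were given three values (one per RGB channel). Broadcasting a 3-value mean against a 1×28×28 tensor requires a 3×28×28 result, which cannot be written back into the 1×28×28 tensor, hence the error message. One fix is to turn each image into a 3-channel (RGB-like) tensor, so I modified the transform definition: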
- Before:
```python
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
```
- After:
```python
# Import libraries
import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.ToTensor(),                           # PIL image -> 1x28x28 float tensor in [0, 1]
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # copy the gray channel 3 times -> 3x28x28
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])  # the modified part

data_train = datasets.MNIST(root="./data",
                            transform=transform,
                            train=True,
                            download=True)
data_test = datasets.MNIST(root="./data",
                           transform=transform,
                           train=False)

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size=64,
                                                shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=64,
                                               shuffle=True)

# Fetch one batch and tile its 64 images into a single grid image
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)  # CxHxW -> HxWxC, the layout matplotlib expects
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean  # undo Normalize so pixel values are back in [0, 1]
print([labels[i] for i in range(64)])
plt.imshow(img)
plt.show()
```
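The last few lines undo the normalization for display. `Normalize` computes `(x - mean) / std` per channel, so `x_norm * std + mean` recovers the original pixel values; the `transpose(1, 2, 0)` matters because it moves the channel axis last, which is what lets the 3-element `std` and `mean` lists broadcast over channels. A small NumPy sketch of that round trip, using a random 28×28×3 array as a stand-in for the transposed grid (not real MNIST data):

```python
import numpy as np

mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])

rng = np.random.default_rng(0)
# Stand-in for an (H, W, C) image with values in [0, 1], like ToTensor() output after transposing
original = rng.random((28, 28, 3))

# What Normalize does per channel (channel axis is last, so broadcasting lines up)
normalized = (original - mean) / std

# What the plotting code does before plt.imshow: invert the normalization
restored = normalized * std + mean

print(normalized.min(), normalized.max())  # within [-1, 1] for mean = std = 0.5
print(np.allclose(restored, original))
```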
- Result: the output first lists the labels of the 64 images, followed by a preview of the 64 images.
```
[tensor(8), tensor(1), tensor(7), tensor(1), tensor(8), tensor(0), tensor(6), tensor(7), tensor(1), tensor(7), tensor(1), tensor(2), tensor(5), tensor(8), tensor(5), tensor(4), tensor(3), tensor(7), tensor(8), tensor(5), tensor(1), tensor(8), tensor(3), tensor(0), tensor(8), tensor(4), tensor(2), tensor(0), tensor(9), tensor(0), tensor(6), tensor(3), tensor(9), tensor(3), tensor(6), tensor(1), tensor(1), tensor(5), tensor(2), tensor(7), tensor(0), tensor(7), tensor(4), tensor(0), tensor(1), tensor(4), tensor(8), tensor(8), tensor(7), tensor(4), tensor(5), tensor(1), tensor(2), tensor(7), tensor(3), tensor(5), tensor(1), tensor(2), tensor(7), tensor(8), tensor(2), tensor(8), tensor(4), tensor(4)]
```
![](https://i-blog.csdnimg.cn/blog_migrate/56e99c0a7b14d793c97625cac199d24f.png)
Note: you can also try the solution proposed by fellow developer Victor_Gui: https://blog.csdn.net/qq_31829611/article/details/90200694