RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

  • 在跑Pytroch的MNIST手写识别例子时,碰到了shape不匹配的错误,错误指向:
images, labels = next(iter(data_loader_train)) 
  • 在尝试过多次之后,发现错误并不是这一句引发的,而是因为图片格式是灰度图只有一个channel,需要变成RGB图才可以,所以将其中一行做了修改:
  • 修改前:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 
  • 修改后:
# 引入库
import torch
from torchvision import datasets, transforms
import torchvision.transforms
from torch.autograd import  Variable
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Lambda(lambda x: x.repeat(3,1,1)),
     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
 ])   # 修改的位置

data_train=datasets.MNIST(root="./data", 
						transform=transform,
						train=True,
						download=True
                          )
data_test=datasets.MNIST(root="./data", 
						transform=transform, 
						train=False)
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,
                                              batch_size=64,
                                              shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,
                                             batch_size=64,
                                             shuffle=True)

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print([labels[i] for i in range(64)])
plt.imshow(img)
  • 结果:可以看到输出的首先是64张图片对应的标签,然后是64张图片的预览结果。[tensor(8), tensor(1), tensor(7), tensor(1), tensor(8), tensor(0), tensor(6), tensor(7), tensor(1), tensor(7), tensor(1), tensor(2), tensor(5), tensor(8), tensor(5), tensor(4), tensor(3), tensor(7), tensor(8), tensor(5), tensor(1), tensor(8), tensor(3), tensor(0), tensor(8), tensor(4), tensor(2), tensor(0), tensor(9), tensor(0), tensor(6), tensor(3), tensor(9), tensor(3), tensor(6), tensor(1), tensor(1), tensor(5), tensor(2), tensor(7), tensor(0), tensor(7), tensor(4), tensor(0), tensor(1), tensor(4), tensor(8), tensor(8), tensor(7), tensor(4), tensor(5), tensor(1), tensor(2), tensor(7), tensor(3), tensor(5), tensor(1), tensor(2), tensor(7), tensor(8), tensor(2), tensor(8), tensor(4), tensor(4)]

注:也可以尝试码友Victor_Gui提出的解决方案:https://blog.csdn.net/qq_31829611/article/details/90200694

  • 71
    点赞
  • 117
    收藏
    觉得还不错? 一键收藏
  • 44
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 44
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值