PyTorch in Jupyter: RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
Error message
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
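Code that triggers it:
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ])
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
dataiter = iter(trainloader)
images, labels = next(dataiter)  # the RuntimeError is raised here, when the transform runs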
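Why this happens: MNIST images are grayscale, so ToTensor() produces a tensor of shape [1, 28, 28], while the mean and std given to Normalize have three values each. The sketch below assumes, as in recent torchvision versions, that Normalize subtracts the mean in place on its (copied) input tensor. It reproduces the same message with plain torch and no dataset download: broadcasting [1, 28, 28] against a [3, 1, 1] mean needs a [3, 28, 28] result, which an in-place operation cannot write into the smaller tensor.

```python
import torch

# Stand-in for one MNIST image after ToTensor(): one channel, shape [1, 28, 28]
img = torch.rand(1, 28, 28)

# Three-value mean, reshaped to [3, 1, 1] so it broadcasts over height and width
mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)

def inplace_subtract_error(x, m):
    """Return the message of the RuntimeError raised by x.sub_(m), or None."""
    try:
        x.clone().sub_(m)  # the broadcast result would have shape [3, 28, 28]
        return None
    except RuntimeError as err:
        return str(err)

message = inplace_subtract_error(img, mean)
print(message)

# Out of place, broadcasting just allocates a new [3, 28, 28] tensor
broadcast = img - mean
print(broadcast.shape)
```

The out-of-place subtraction succeeds, which shows the shapes themselves are compatible; only writing the result back into the one-channel tensor fails.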
Solution:
Add one line to the transform so that it becomes:
transform = transforms.Compose([
    transforms.ToTensor(),                            # shape [1, 28, 28], values in [0, 1]
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # copy the channel 3 times: [3, 28, 28]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
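With this change each image has three identical channels, matching the three values of mean and std, and the error goes away.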