在识别手写数字时,出现维度不匹配的情况
print(iter(train_loader).next()[0].size())
其实并不是这行代码的问题,而是trandform的问题:
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
我们要注意的是MINIST的图片是灰度图,chanel为1,在transform中标准化的时候改为:
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)])
就可以正常运行