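When you use the datasets.ImageFolder class, it automatically converts your data to three-channel RGB. If your input data is single-channel grayscale and you change the input channel count in model.py from 3 to 1, you get the following error: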
```
Traceback (most recent call last):
  File "new_train.py", line 80, in <module>
    logits = net(images.to(device))
  File "/home/jiacong/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/jiacong/Experiments/phytoplankton-classification-grayimage/Test5_resnet/model.py", line 119, in forward
    x = self.conv1(x)
  File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 338, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 1 7 7, expected input[16, 3, 224, 224] to have 1 channels, but got 3 channels instead
```
Changing it back to 3 channels, on the other hand, runs without errors.
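The mismatch itself can be reproduced without ImageFolder. The sketch below assumes a ResNet-style first layer `nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)`, i.e. what conv1 looks like after editing the channel count to 1 (the actual model.py is not shown here), and feeds it a batch shaped like the one in the traceback.

```python
import torch
import torch.nn as nn

# Assumed ResNet-style stem: a 7x7 conv that expects 1 input channel,
# as after changing the channel count in model.py from 3 to 1.
conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# A batch shaped like the traceback's input: 16 RGB images of 224x224.
rgb_batch = torch.randn(16, 3, 224, 224)
try:
    conv1(rgb_batch)
    mismatch_raised = False
except RuntimeError as err:
    # Same kind of error as above: the layer wants 1 channel but gets 3.
    mismatch_raised = True
    print(err)

# A genuinely single-channel batch passes through the same layer.
gray_batch = torch.randn(16, 1, 224, 224)
out = conv1(gray_batch)
print(out.shape)  # torch.Size([16, 64, 112, 112])
```

So the model change is fine; the problem is that the data still arrives with 3 channels.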
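This is because the official default_loader (which delegates to pil_loader when the default PIL image backend is used) is implemented like this: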
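```python
from PIL import Image

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Whatever the file's original mode (e.g. 'L'), the result is RGB.
        return img.convert('RGB')
```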
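To see the conversion happen, the sketch below writes a small single-channel ('L' mode) PNG to a temporary directory and loads it with the pil_loader logic quoted above (redefined here so the snippet is self-contained). The file name, size, and gray value are arbitrary choices for illustration.

```python
import os
import tempfile

from PIL import Image


def pil_loader(path):
    # Same logic as the loader quoted above: always convert to RGB.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'gray.png')
    Image.new('L', (32, 32), color=128).save(path)  # a 1-channel grayscale image

    with Image.open(path) as original:
        original_mode = original.mode  # 'L' on disk
    loaded = pil_loader(path)          # forced to RGB by the loader
    print(original_mode, loaded.mode, loaded.getbands())  # L RGB ('R', 'G', 'B')
```

The gray value is simply copied into all three bands, so the image looks the same but now has 3 channels.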
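In other words, whatever mode your image is in, the loader converts it to RGB, so the input to the first convolutional layer has channel=3, not the channel=1 you expected, and that is why the error occurs. If you want to read the images as grayscale, pass a custom loader to datasets.ImageFolder that does not convert them to RGB: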
```python
from PIL import Image

def my_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Convert to single-channel grayscale ('L') instead of RGB.
        img = img.convert('L')
        return img
```
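Passing loader=my_loader to datasets.ImageFolder is enough to load the images as grayscale. One more note: when you apply transforms.Normalize to grayscale images, give mean and std a single value each (one per channel):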
```python
from torchvision import transforms
transforms.Normalize(mean=(0.5,), std=(0.5,))
```
Besides defining a custom loader, you can get the same effect by adding transforms.Grayscale(1) to your transforms (place it before transforms.ToTensor()).