前言
ResNet是一个比较成熟的深度学习分类模型,目前有ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152,同时,该分类模型常用于RGB(三通道)彩色图像的分类任务,如在ImageNet的训练;而在单通道图像(灰度图像)的训练和测试较少。如何使ResNet在单通道图像上训练,如何修改网络模型参数和读取图像,本文将一一进行讲解。
步骤
第一步:构建数据集
- 数据集的结构应该是这样的
- 图像的格式:8bit,jpg格式
第二步:修改网络模型
- 法1:直接修改定义的ResNet网络模型
在model.py中,修改ResNet的第一层卷积层输入通道为1(彩色为3)
self.conv1 = nn.Conv2d(1, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)
- 法2:在train.py文件中,进行如下修改,也可以达到法1的效果
model = resnet18(num_classes=3)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model = model.to(device)
第三步:修改读取数据方式
-
一般我们用torchvision.datasets.ImageFolder()读取数据,但在读取单通道数据时,此函数会自动将单通道图像转换为三通道图像(r=g=b),此时如果不进行其他操作,就会报错
-
这是ImageFolder()函数的定义:留意读取的图像为PIL图像,且会转换为RGB格式
-
修改方法:
修改transform(图像预处理操作)
添加transforms.Grayscale(1),将图像转换为单通道图像(经实验,图像矩阵的数据并不会发生变化)
transforms.Normalize修改如下,第一个参数为mean,第二个参数为std,因为是单通道,所以进行Z-Score时仅需要对一个通道进行操作
data_transforms = {
'train': transforms.Compose([
transforms.Grayscale(1),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, ], [0.229, ])
]),
'val': transforms.Compose([
transforms.Grayscale(1),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, ], [0.229, ])
])
}
第四步:训练分类网络并测试(注意测试时transform与‘val’方式一样)
总结
ResNet训练单通道主要修改两个部分,一个是ResNet模型第一层卷积层的in_channels=1,另一个是transform中添加Grayscale(1)以及修改Normalize。其实很简单,只是有时忽略了ImageFolder会自动将灰度图转换为RGB图,导致出错,希望本文能帮助您!
参考资料:
可以参考这位up主github里面的Test5_resnet,并在此基础上进行上述修改,训练自己的灰度图像!
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_learning