使用灰度图作为数据集训练CNN

拿到一个全是灰度图的数据集,原本的模型输入的3通道的RGB图像。

1. 修改模型结构

input和output从3改成1:

    def _build_models(self):
        self.features = nn.Sequential(
            self._conv2d(1, self.hidden_size),      # 输入是灰度图,把通道改成1
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(self.hidden_size),
        )
        self.layers = nn.Sequential(
            self._conv2d(self.hidden_size + self.data_depth, self.hidden_size),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(self.hidden_size),
            self._conv2d(self.hidden_size, self.hidden_size),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(self.hidden_size),
            self._conv2d(self.hidden_size, 1),      # 输出是灰度图,把通道改成1
            nn.Tanh(),
        )
        return self.features, self.layers

注意:要把所有用到的模型的输入、输出都改了,我用到了3个模型,其中一个的输入漏改,找问题找半天T T

2. 修改数据

原本数据载入方式:

    train = DataLoader(os.path.join("data", args.dataset, "train"), shuffle=True)
    validation = DataLoader(os.path.join("data", args.dataset, "val"), shuffle=False)

直接在transform中,加入灰度转换操作,但是原本的代码中没有使用Dataset,直接载入数据的路径了。第一反应是自己写一个Dataset,但是感觉有点麻烦。
因为这不是分类任务,不需要用到label,想到之前用过的CartoonGAN里的数据加载方式,可以套过来用:
先定义一个数据加载的函数

def data_load(path, subfolder, transform, batch_size, shuffle=False, drop_last=True):
    dset = datasets.ImageFolder(path, transform)
    ind = dset.class_to_idx[subfolder]

    n = 0
    for i in range(dset.__len__()):
        if ind != dset.imgs[n][1]:
            del dset.imgs[n]
            n -= 1

        n += 1

    return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

再直接调用这个函数来载入数据,记得加上修改好的transform函数:

data_transform = transforms.Compose([
    transforms.Grayscale(1),  # 添加transforms.Grayscale(1),将图像转换为单通道图像
    # transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(0.485, 0.229, inplace=True),		# transforms.Normalize修改如下,第一个参数为mean,第二个参数为std,因为是单通道,所以进行Z-Score时仅需要对一个通道进行操作,所以mean和std只需要一个值就行
])

train = utils.data_load(os.path.join("data", args.dataset), "train", transform=data_transform, batch_size=2, shuffle=True, drop_last=True)
validation = utils.data_load(os.path.join("data", args.dataset), "val", transform=data_transform, batch_size=2, shuffle=True, drop_last=True)

3. 成功运行

在这里插入图片描述

评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值