TensorFlow 2 自定义生成图像的 Generator

本文实现了一个自定义的 Generator, 从文件夹中读取图片, 然后进行45 * randn(8) 角度的旋转

def custom_generator(shuffle):
    import skimage.io
    import skimage.transform
    import numpy as np
    from pandas import DataFrame
    import os
    from tensorflow.python.keras.utils.data_utils import Sequence

    class MySequence(Sequence):
        def __init__(
                self,
                directory,
                batch_size=32,
                shuffle=True
        ):
            self.batch_size = batch_size
            categories = list(os.listdir(directory))  # 查看目录下有那些文件夹, 作为分类名
            indexes = []  # 保存索引(文件路径)
            labels = []  # 保存标签
            for category in categories:
                files = [os.sep.join([directory, category, f]) for f in os.listdir(directory + os.sep + category)]
                indexes += files
                labels += [category] * len(files)
            self.df = DataFrame({"label": labels}, index=indexes)
            self.indices = self.df.index.tolist()  # 将 DataFrame 中的文件路径索引导出, 后面要用
            self.shuffle = shuffle
            self.num_classes = len(categories)
            self.on_epoch_end()  # 每个 epoch 结束都需要重新打乱数据
            self.category_to_id = {c: i for i, c in enumerate(categories)}  # 保存类别和类别ID的映射关系

        def on_epoch_end(self):
            self.index = np.arange(len(self.indices))
            if self.shuffle:
                np.random.shuffle(self.index)

        def __len__(self):
            return len(self.indices) // self.batch_size

        def __getitem__(self, index):
            """生成一批数据"""
            # 选择对应的一批数据的索引(数字)
            index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
            # 找到上述索引对应的文件路径
            batch = [self.indices[k] for k in index]  # 取出文件名
            # 按照文件路径生成数据
            X, y = self.__get_data(batch)
            return X, y

        def __get_data(self, batch):
            # 按照文件路径生成数据
            X = batch  # 文件路径
            y = [self.df.at[b, 'label'] for b in batch]  # 对应的标签

            # 从文件读取图像并进行变换
            for i, file_name in enumerate(X):
                im = skimage.io.imread(file_name, as_gray=True)  # 读取图像, 如果是彩色图像要把 as_gray 去掉
                # 通过四周填充黑色, 调整图片大小
                im = skimage.transform.resize(im, output_shape=(200, 200), mode='constant', cval=0)
                # 旋转 45 度的倍数
                im = skimage.transform.rotate(im, np.random.randint(8) * 45)
                X[i] = im  # 将图像保存到 X 中, scikit-image 对图像的操作结果都是 numpy 矩阵
            X = np.array(X)
            y = [self.category_to_id[t] for t in y]  # 将标签转换成序号
            
            # 下面选择是用序号还是转换成 one-hot 形式
            y = np.array(y)  # 直接用序号表示分类, loss 要用 SparseCategoricalCrossentropy
            # y = np.eye(self.num_classes)[y]  # 转换成 one-hot 形式, loss 要用 CategoricalCrossentropy
            return X, y

    return MySequence(directory=r'F:\Competion\project\PET\data\images\train_crop', batch_size=32, shuffle=shuffle)

train_crop 文件夹下面有两个文件夹, CN 和 AD, 代表图像的分类, 文件夹下面是具体的图片.

下面是用一个只有全连接层的模型进行测试. 模型非常简单, 只有一个全连接层, 主要目的是检查自定义的 Generator 能不能使用.

def simple_model():
    from tensorflow import keras
    from tensorflow.keras import models
    from tensorflow.keras import layers
    model = models.Sequential()
    model.add(layers.Flatten(input_shape=[200,200]))
    model.add(layers.Dense(2))
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
                  optimizer=keras.optimizers.SGD(learning_rate=1e-3),
                  metrics=['accuracy'])
    model.summary()
    history = model.fit(custom_generator(shuffle=True), epochs=10, validation_data=custom_generator(shuffle=False))
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页