卷积神经网络<三>Keras多通道数据生成器

数据生成器就是一个随机读取文件的过程
数据生成器的作用是输入的数据太大,无法一次放入内存中的时候,就需要一个batch一个batch的读取。
keras的数据生成器最常用的办法是继承ut.Sequence这个类,然后生成一个数据生成器类。
使用的时候可以直接使用fit_generator()函数进行

这里写目录标题

核心方法

其他的都可以不用改变,修改这个读取文件的逻辑就行了。这里可以通过参数n_image,读取一个、两个或者三个通道的图片作为输入。也可以把文件存成Npy形式,读取三个npy文件。
核心方法就是文件的读取。 把每一个特征图存储下来就行了。

 # 核心的方法
    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X_rri = np.empty((self.batch_size, *self.dim, self.n_channels))
        X_amp = np.empty((self.batch_size, *self.dim, self.n_channels))
        X_edr = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            rri = cv2.imread(self.data_path + "rri/" + ID + ".png").astype("float32") / 255
            amp = cv2.imread(self.data_path + "amp/" + ID + ".png").astype("float32") / 255
            edr = cv2.imread(self.data_path + "edr/" + ID + ".png").astype("float32") / 255
            X_rri[i, ] = rri
            X_amp[i, ] = amp
            X_edr[i, ] = edr

            # Store class,
            y[i] = self.labels[ID]
        ret = []
        if self.n_image == 1:
            ret = [X_rri]
        elif self.n_image == 2:
            ret = [X_rri, X_edr]
        else :
            ret = [X_rri, X_edr, X_amp]
        return ret, ut.to_categorical(y, num_classes=self.n_classes)

定义代码

"""
@author:fuzekun
@file:Data_Gegerator.py
@time:2022/11/19
@description:

1. 需要确定生成数据的ID是否是从0开始的 index = 0 * index


1. 随机生成需要读取的id
2. 初始化一个空列表,然后将图片读取出来
3. 返回读取的图片和one hot之后的label


注意 import keras.utils.all_utils as ut
不要直接使用utils了,会报错,版本问题

注意返回三个numpy文件
"""
import cv2
import numpy as np
import keras
import keras.utils.all_utils as ut

class DataGenerator(ut.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, labels, data_path, batch_size=32, dim=(2,16,16), n_image = 3, n_channels=3,
                 n_classes=2, shuffle=True):
        'Initialization'
        self.data_path = data_path      # 数据的路径
        self.dim = dim                  # 每一个文件的维度
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_image = n_image          # 输入图片的数量rri, 还是rri + edr,还是三个
        self.n_channels = n_channels    # 图的通道,3通道就是3
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):       # 初始是从0是开的
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    # 核心的方法
    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X_rri = np.empty((self.batch_size, *self.dim, self.n_channels))
        X_amp = np.empty((self.batch_size, *self.dim, self.n_channels))
        X_edr = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            rri = cv2.imread(self.data_path + "rri/" + ID + ".png").astype("float32") / 255
            amp = cv2.imread(self.data_path + "amp/" + ID + ".png").astype("float32") / 255
            edr = cv2.imread(self.data_path + "edr/" + ID + ".png").astype("float32") / 255
            X_rri[i, ] = rri
            X_amp[i, ] = amp
            X_edr[i, ] = edr

            # Store class,
            y[i] = self.labels[ID]
        ret = []
        if self.n_image == 1:
            ret = [X_rri]
        elif self.n_image == 2:
            ret = [X_rri, X_edr]
        else :
            ret = [X_rri, X_edr, X_amp]
        return ret, ut.to_categorical(y, num_classes=self.n_classes)

调用代码

partition_train = [str(x) for x in range(len(train_label))]
partition_test = [str(y) for y in range(len(test_label))]
train_label = dict((str(x), train_label[x]) for x in range(len(train_label)))
test_label = dict((str(x), test_label[x]) for x in range(len(test_label)))


# 注意generator了几个个numpy要和input的图片对应
train_generator = DataGenerator(
    list_IDs=partition_train,
    data_path=train_data_path,
    labels=train_label,
    batch_size=batch_size,
    n_image=n_image,
    dim=img_shap,  # 3张图, 16 * 16
    n_channels=n_chanels,  # 三种颜色
    n_classes=2,
    shuffle=False
)
test_generator = DataGenerator(
    list_IDs=partition_test,
    data_path=test_data_path,
    labels=test_label,
    batch_size=batch_size,
    n_image=n_image,
    dim=img_shap,  # 图片的大小 16 * 16
    n_channels=n_chanels,  # 三种颜色
    n_classes=2,
    shuffle=False
)
print("--------------------------------------")
print("生成器定义完成, 注意生成器中的label应该是字典")
print("--------------------------------------")

print("-----------------------------------")
print("开始训练")
print("-----------------------------------")
model.fit_generator(
    generator = train_generator,
    epochs = training_epochs,
    validation_data = test_generator,
    use_multiprocessing = False,
   # max_queue_size=4,
    callbacks = callbacks,
    verbose = 1
)
print("---------------------------------------")
print(" 训练完成 ")
print("---------------------------------------")

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值