【tf.keras.utils.Sequence】构建自己的数据集生成器

every blog every motto: You can do more than you think.

0. 前言

在训练模型时,我们往往不一次将数据全部加载进内存中,而是将数据分批次加载到内存中。


  • 一种方法是用 while True 遍历数据,用yeid产生,具体可参考语义分割代码讲解部分
  • 另一种方法是本文即将讲解的tf.keras.utils.Sequence方法

1. 正文

1.1 基础用法

__ len __ 中返回的即1个epoch迭代的次数,即:
总样本数/ batch_size

__ getitem __ 根据len中的迭代次数,生成数据


注意: __ len __ ,__ getitem __ 必须要实现

"""
测试
__getitem__
"""
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf


class Date(tf.keras.utils.Sequence):

    def __init__(self):
        print('初始化相关参数')

    def __len__(self):
        """
        此方法要实现,否则会报错
        正常程序中返回1个epoch迭代的次数
        :return:
        """
        return 5

    def __getitem__(self, index):
        """生成一个batch的数据"""
        print('index:', index)
        x_batch = ['x1', 'x2', 'x3', 'x4']
        y_batch = ['y1', 'y2', 'y3', 'y4']
        print('-'*20)
        return x_batch, y_batch


# 实例化数据
date = Date()

for batch_number, (x, y) in enumerate(date):
    print('正在进行第{} batch'.format(batch_number))
    print('x_batch:', x)
    print('y_batcxh:', y)

结果:
在这里插入图片描述

1.2 扩展(2020.11.12 15:37增补)

可以在类中实现on_epoch_end方法,保证在每个epoch后打乱原有数据的顺序

1.2.1 训练样例:

测试代码,如下:

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np

print('tensorflow version: ', tf.__version__)


class ZerosFirstEpochOnesAfter(tf.keras.utils.Sequence):
    def __init__(self):
        self.shuffle = True

    def __len__(self):
        return 2

    def on_epoch_end(self):
        print('---------------on_epoch_end------------')

        # 打乱索引
        # if self.shuffle:
        #     print('==============================================================shuffle')
        #     np.random.shuffle(self.indices)

    def __getitem__(self, item):

        return np.zeros((16, 1)), np.zeros((16,))



def main():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_dim=1, activation="softmax"))

    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy']
                  )

    model.fit(ZerosFirstEpochOnesAfter(), epochs=3, )


if __name__ == '__main__':
    main()

tensorflow 2.0:
在这里插入图片描述

tensorflow 2.1:

在这里插入图片描述

tesorflow 2.3:
在这里插入图片描述
由以上三个版本的训练结果,我们可以发现,

  • 在2.0和2.1版本中,是没有进行on_epoch_end方法调用的,即没有实现on_epoch_end方法内注释部分的打乱顺序,这是tensorflow早期版本的一个bug,具体可参考文后第4个链接。
  • 在2.3版本中已得到改进

1.2.2 循环遍历:

1.2.2.1 原始版测试

循环遍历,如下所示:

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np

print('tensorflow version: ', tf.__version__)

class Date(tf.keras.utils.Sequence):

    def __init__(self):
        print('初始化相关参数')
        self.lines = [1,2,3,4,5]
        self.shuffle = True

    def __len__(self):
        """
        此方法要实现,否则会报错
        正常程序中返回1个epoch迭代的次数
        :return:
        """
        return 2

    def on_epoch_end(self):
        print('=======================')
        if self.shuffle == True:
            print('------------一个epoch结束,打乱了顺序---')
            np.random.shuffle(self.lines)

    def __getitem__(self, index):
        """生成一个batch的数据"""
        print('index:', index)
        x_batch = ['x1', 'x2', 'x3', 'x4']
        y_batch = ['y1', 'y2', 'y3', 'y4']
        print('-' * 20)
        return x_batch, y_batch


# 实例化数据
date = Date()

for epoch in range(2):
    for batch_number, (x, y) in enumerate(date):
        print('正在进行第{} batch'.format(batch_number))
        print('x_batch:', x)
        print('y_batcxh:', y)
    print('一个epoch结束=============================')

结果:
在这里插入图片描述
如上图所示,通过循环遍历这种方法仍然不能调用on_epoch_end,即无法打乱顺序

1.2.2.2 改进版
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np

print('tensorflow version: ', tf.__version__)


class Date(tf.keras.utils.Sequence):

    def __init__(self):
        print('初始化相关参数')
        self.lines = [1,2,3,4,5]
        self.shuffle = True

    def __len__(self):
        """
        此方法要实现,否则会报错
        正常程序中返回1个epoch迭代的次数
        :return:
        """
        return 2

    def on_epoch_end(self):
        print('=======================')
        if self.shuffle == True:
            print('------------一个epoch结束,打乱了顺序---')
            np.random.shuffle(self.lines)

    def __getitem__(self, index):
        """生成一个batch的数据"""
        print('index:', index)
        x_batch = ['x1', 'x2', 'x3', 'x4']
        y_batch = ['y1', 'y2', 'y3', 'y4']
        print('-' * 20)
        return x_batch, y_batch


# 实例化数据
date = Date()

for epoch in range(2):
    print(date.lines)
    for batch_number, (x, y) in enumerate(date):
        print('正在进行第{} batch'.format(batch_number))
        print('x_batch:', x)
        print('y_batcxh:', y)
    np.random.shuffle(date.lines)
    print('一个epoch结束=============================')

如下图所示,我们发现已经打乱了“样本”顺序,
在这里插入图片描述

参考文献

[1] https://blog.csdn.net/weixin_39190382/article/details/105808830
[2] https://blog.csdn.net/weixin_43198141/article/details/89926262
[3] https://blog.csdn.net/u011311291/article/details/80991330
[4] https://github.com/tensorflow/tensorflow/issues/35911
[5] https://colab.research.google.com/gist/bfs15/fd18263f788a071225c60cedaf126748/35911.ipynb

### 回答1: tf.keras.utils.image_dataset_from_directory是一个函数,用于从目录中读取图像数据集并返回一个tf.data.Dataset对象。它可以自动将图像数据集划分为训练集和验证集,并对图像进行预处理和数据增强。此函数是TensorFlow Keras API的一部分,用于构建深度学习模型。 ### 回答2: tf.keras.utils.image_dataset_from_directory是一个用于从文件夹中加载图像数据集的实用函数。该函数以指定的文件夹路径作为输入,自动将文件夹中的图像按照类别划分,并生成一个tf.data.Dataset对象,用于训练或评估深度学习模型。 该函数的主要参数包括: - directory:指定的文件夹路径,用于加载图像数据集。 - labels:可选参数,指定是否从文件夹的子文件夹中自动提取类别标签。 - label_mode:可选参数,指定类别标签的返回类型。支持"categorical"、"binary"、"sparse"和"int"四种类型。 - batch_size:指定生成的Dataset对象中每个batch的样本数量。 - image_size:可选参数,指定生成的样本的图像大小。 - validation_split:可选参数,指定用于验证集划分的比例。 当调用该函数时,首先会通过遍历指定路径下的所有图片文件,自动提取所有类别的名称。然后,根据提取的类别信息,将文件夹中的图像按照类别划分,并为每个类别生成一个不同的整数标签。最后,将这些划分好的图像数据转换为tf.data.Dataset对象,并将类别标签与样本数据一一对应。 最终生成的Dataset对象中,每个样本都是一个元组,包含图像数据和对应的类别标签。该Dataset对象可以直接用于训练或评估深度学习模型,并且可以通过设置参数来自动进行数据增强、批处理等操作。 使用tf.keras.utils.image_dataset_from_directory函数,可以方便地加载和处理大量的图像数据集,提高模型训练的效率和准确率。 ### 回答3: tf.keras.utils.image_dataset_from_directory是一个用于从文件目录中加载图像数据集的函数。它基于TensorFlow的Keras API,并提供了一种方便的方式来准备图像数据集进行训练和验证。 该函数能够自动地从文件目录读取图像,并创建一个TensorFlow数据集对象,其中每个图像与其标签关联。使用该函数,可以轻松地从文件夹中加载具有不同类别的图像数据,并自动将其划分为训练集和验证集。可以指定训练集和验证集的比例、图像的大小、批次大小等参数。此外,还可以进行数据预处理操作,如图像放缩、归一化等。 使用该函数的步骤如下: 1. 准备图像数据集:将不同类别的图像按照标签存储在不同的文件夹中。 2. 调用image_dataset_from_directory函数:指定图像文件夹的路径,并设定其他参数如图像大小、批次大小等。 3. 接收返回的数据集对象:该对象包含训练集和验证集。 4. 可以将该数据集对象直接用于模型的训练、评估和推理。 该函数的优点是简单易用,能够快速地加载图像数据集,并且能够与tf.data API无缝集成,方便进行数据增强、数据流水线等高级操作。它减少了手动处理图像数据的工作量,使得图像分类、目标检测等任务更加高效。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

胡侃有料

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值