tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator

tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练。

也有tensorflow中的 tf.data.DataSet的使用。并且由于是tensorflow框架的内容,会让工程看起来更加连贯流畅。

这里我们需要先了解 tf.data 下的两个类:

  • tf.data.DataSet:将我们的numpy数据 转换成 tensorflow的DataSet数据
  • tf.data.Iterator:生成DataSet的迭代器,来源源不断的获取数据传送送神经网络。

接下来我们用例子说明:

import tensorflow as tf
import numpy as np

data_num = 100
batch = 5

def dealdata(image, label):
    """该函数下使用的是python的语法进行处理数据
    注意:返回数据基本需要加上 .astype(np.float32) 或 .astype(np.int32),否则可能会类型问题上报错,"""
    image = image * 2
    label = label * 2
    return image.astype(np.int32), label.astype(np.int32)

def map_func(image, label):
    """
    该函数下的内容是要使用tensorflow的语法实现的数据处理。
    如果数据处理比较复杂,可以使用tf.py_func调用python的处理函数
    这里 tf.py_func的入参最后一项,数据的类型和个数,是根据调用的函数具体返回数据填写的,不必与get_batch_gen保持一致!!!
    """
    image = image + 1
    image, label = tf.py_func(dealdata, [image, label], [tf.int32, tf.int32])
    return image, label


def get_batch_gen(split):
    """
    该函数返回的内容,是 tf.data.Dataset.from_generator()函数的三个入参
    Returns:
        batch_gen: 数据信息的迭代器的函数,
        gen_types:tf.data.Dataset.from_generator()迭代出的数据的类型
        gen_shape:f.data.Dataset.from_generator()迭代出的数据的shape
    """

    if (split == "train"):
        image = list(range(0, data_num))  # 这里用数字代表输入的数据信息,实际使用可以是数据的路径等
        label = list(range(0, data_num))  # 使用数字代表数据的标签
    if (split == "val"):
        # image = ...
        # label = ...
        print()

    def batch_gen():
        for i in range(data_num):  # 这里循环的次数,是网络训练的次数(具体前向传播的次数)
            yield ([image[i]], [label[i]])

    gen_types = (tf.int32, tf.int32)
    gen_shape = ([None], [None])
    return batch_gen, gen_types, gen_shape


class Dataset:
    def init_input_pipeline(self):
        gen_function, gen_types, gen_shapes = get_batch_gen("train")

        """将python定义的 数据信息的迭代器的内容,转换成tensorflow的 DataSet,该 DataSet有两个属性:.output_shapes/.output_types"""
        data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes)
        # print(data.output_shapes)
        # print(data.output_types)

        """设置了输出数据的batch。虽然返回的data 是tensorflow中另外一种类 BatchDataSet,其实就是设置batch"""
        # data = data.batch(batch)
        data = data.shuffle(data_num).batch(batch)  # 读取出了 data_num 个数据后,就打乱。
        data = data.map(map_func=map_func, num_parallel_calls=4)  # 进行多进程数量,来并行处理数据。map_func为处理函数
        data = data.prefetch(buffer_size=batch * 10)  # 提前准备数据的数量,提前处理了数据,就可以让gpu尽可能少的处于等待数据的状态

        """构造一个DataSet的迭代器"""
        iter = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
        self.init_op = iter.make_initializer(data)  # 迭代器的初始化
        self.flat_inputs = iter.get_next()  # 使用get_next源源不断的获取数据
        """需要说明的,这里的self.flat_inputs,就可以想tf.placehoder一样,直接作为神经网络的输入节点,在其后面一层一层的定义网络结构。"""


class network():
    def __init__(self, flat_inputs):
        self.input = {}
        self.input["image"] = flat_inputs[0]  # 类比于输入的placehoder的地位
        self.input["label"] = flat_inputs[1]  # 类比于label的placehoder的地位

        # ... 具体神神经网络结构
        self.out = ...


if __name__ == '__main__':

    trian_Dataset = Dataset()
    trian_Dataset.init_input_pipeline()

    # model = network(trian_Dataset.flat_inputs)  # 初始化神经网络
    with tf.Session() as sess:
        for i in range(2):
            sess.run(trian_Dataset.init_op)
            try:
                while True:
                    a, b = sess.run(trian_Dataset.flat_inputs)  # 直接运行 数据读取的get_next() 节点,即可获取到数据。
                    print(a.reshape(-1), b.reshape(-1))
                    # out = sess.run(model.out)  # 运行网络的输出节点,即可得到网络输出数据

            except tf.errors.OutOfRangeError:
                # 迭代器数据取完时,直接跳转在这里,防止运行中断
                print("outOfRange")

其中,最固定化的流程是 dataset.init_input_pipeline()。

  • 主要根据【python的迭代器】定义一个【tf.data.DataSet】,然后设置【batch】,再通过【map】来处理多进程处理数据(复杂的数据处理可以通过tf.py_func来调用python语法的处理函数),并使用【prepare】设置提前准备数据的个数,使gpu尽量少的处于等待状态。
  • 定义 tf.data.DataSet 的迭代器 【tf.data.Iterator】,设置初始化节点【make_initializer】以及获取数据节点【get_next】

上面的例子的打印结果如下:

关于自己的测试和问题:

本人使用以上的代码来进行数据读取,在map_func函数中,添加【time.slaeep(2)】,模拟处理复杂数据耗时;然后在map的入参中设置了多进程的数量。然后测试发现,数据获取运行时间并未得到减少。

另外在github上下载论文开源工程,修改map的多进程数量入参,测试用时还是相同的情况。

这里我挠挠脑袋,有点疑惑不知道为什么。这里就先记录到这,没准后面那次再测其他工程就OK了

上面的形式,在我们阅读源码网络结构时是不方便的,直接打印神经网络的某一层,得到tensor的shape基本是 (?,?,?...)的形式,这不是我们需要看到的。那么阅读网络代码时,如何显示化的tensor的shape呢?如下代码,

    ...
    dataset.init_input_pipeline()
    #
    with tf.Session() as sess:
        sess.run(dataset.train_init_op)
        a = sess.run(dataset.flat_inputs)
        for j in range(len(a)):
            dataset.flat_inputs[j].set_shape(list(a[j].shape))
    #
    model = Network(dataset, cfg)
    ...

如果需要使用tf.placehoder()代替输入的位置,继而能够进行另外的操作:

    ...
    dataset.init_input_pipeline()
    #
    PlaceHolder = []
    with tf.Session() as sess:
        sess.run(dataset.train_init_op)
        a = sess.run(dataset.flat_inputs)
        for j in range(len(a)):
            PlaceHolder.append(tf.placeholder(dataset.flat_inputs[j].dtype, a[j].shape, "pl_{}".format(j)))
            print(PlaceHolder[-1])
    dataset.flat_inputs = PlaceHolder
    #
    model = Network(dataset, cfg)
    ...
 


 

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
tf.data.Dataset是TensorFlow中用于处理数据的模块,它提供了一种高效且易于使用的数据输入方式,可以处理大量的数据并且可以轻松地与TensorFlow中的其他模块集成。 使用tf.data.Dataset有以下几个步骤: 1. 创建Dataset对象 可以通过多种方式创建Dataset对象,例如: - 从Tensor创建:tf.data.Dataset.from_tensor_slices(tensor) - 从numpy数组创建:tf.data.Dataset.from_tensor_slices(numpy_array) - 从文件创建:tf.data.Dataset.from_tensor_slices(file_paths) 2. 对数据进行转换和处理 Dataset对象可以应用多种转换和处理函数,例如: - map():对每个元素应用一个函数 - filter():根据条件过滤数据 - batch():将数据分成小批次 - shuffle():随机打乱数据 可以通过链式调用这些函数来对数据进行处理和转换。 3. 创建迭代器 可以使用Dataset对象的make_one_shot_iterator()方法创建一个迭代器,该迭代器将按顺序遍历Dataset对象中的每个元素。也可以使用make_initializable_iterator()方法创建一个可初始化的迭代器,需要在使用迭代器之前调用迭代器的initialize()方法初始化。 4. 使用迭代器读取数据 可以使用迭代器的get_next()方法获取下一个元素。在使用Session运行TensorFlow图时,可以将get_next()方法的结果作为feed_dict的值传递给模型。 示例代码: ```python import tensorflow as tf # 创建Dataset对象 data = tf.data.Dataset.from_tensor_slices(tf.range(10)) # 对数据进行转换和处理 data = data.filter(lambda x: x % 2 == 0) data = data.map(lambda x: x * 2) data = data.shuffle(buffer_size=10) data = data.batch(batch_size=2) # 创建迭代器 iterator = data.make_initializable_iterator() # 使用迭代器读取数据 with tf.Session() as sess: sess.run(iterator.initializer) while True: try: batch = sess.run(iterator.get_next()) print(batch) except tf.errors.OutOfRangeError: break ``` 此代码将创建一个包含数字0到9的Dataset对象,并对其进行过滤、映射、随机打乱和分批处理。然后创建一个可初始化的迭代器,并使用Session运行TensorFlow图来逐批处理数据

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值