tf.data.Dataset与tfrecord学习笔记

目录

1.tf.data.Dataset

2.tfrecord

2.1 使用tfrecord的原因

2.2 tfrecord的写入

2.3 tfrecord的读取

3.两种方式的区别

参考资料:


1.tf.data.Dataset

# 从tensor中获取数据

dataset = tf.data.Dataset.from_tensor_slices(img_paths)

 

# 可选项,从数据集中过滤数据

dataset = dataset.filter(filter)

 

# 数据解析,原来可能是路径,需要变成真实的图片数据

# 其中num_parallel_calls表示并行操作的线程数量,一般设置为CPU核心数量为最好

dataset = dataset.map(map_func, num_parallel_calls=num_threads)

 

# 打乱数据,这里有个buffer_size,表示每次从这个buffer_size个数据中随机一个位置,与buffer_size外的数据进行交换

dataset = dataset.shuffle(buffer_size)

 

# 把数据集组装成batchs

if drop_remainder:

    # 将dataset切分成n个batch_size,并且决定是否丢掉最后一个不满足一个batch的数据

    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

else:

    dataset = dataset.batch(batch_size)

 

# repeat表示重复的次数,-1表示重复无限次,这样就永远不会报outOfRange这种错

"""

tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。

具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。

prefech(n) 一般作为最后一个 transformation,其中 n 为 batch_size。

因为数据已经分成了好几个batch,那么这句话其实就是在数据被请求之前,预加载2个batch的数据

"""

dataset = dataset.repeat(repeat).prefetch(prefetch_batch)

 

# 建立一个迭代器

iterator = dataset.make_initializable_iterator()

batch_data = iterator.get_next()

 

# 建立一个会话Session

with tf.Session as sess:

    # 初始化迭代器

    sess.run(iterator.initializer)

    data = sess.run(batch_data)

实测代码:

import os


import numpy as np
import tensorflow as tf
from tflib.utils import session
import random


img_paths = "E:\\python_project\\DeeCamp\\data\\list_attr_celeba.txt"
buffer_size = 4096
drop_remainder = True
batch_size = 32
repeat = -1
prefetch_batch = 2


names = np.loadtxt(img_paths, skiprows=2, usecols=[0], dtype=np.str)


print("start")
# 从tensor中获取数据
dataset = tf.data.Dataset.from_tensor_slices(names)


print("read files over")


# 可选项,从数据集中过滤数据
# dataset = dataset.filter()


# 数据解析,原来可能是路径,需要变成真实的图片数据
# 其中num_parallel_calls表示并行操作的线程数量,一般设置为CPU核心数量为最好
# dataset = dataset.map(map_func, num_parallel_calls=num_threads)


# 打乱数据,这里有个buffer_size,表示每次从这个buffer_size个数据中随机一个位置,与buffer_size外的数据进行交换
dataset = dataset.shuffle(buffer_size)


# 把数据集组装成batchs
if drop_remainder:
    # 将dataset切分成n个batch_size,并且决定是否丢掉最后一个不满足一个batch的数据
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
else:
    dataset = dataset.batch(batch_size)




# repeat表示重复的次数,-1表示重复无限次,这样就永远不会报outOfRange这种错
"""
tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。
具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。
prefech(n) 一般作为最后一个 transformation,其中 n 为 batch_size。
因为数据已经分成了好几个batch,那么这句话其实就是在数据被请求之前,预加载2个batch的数据
"""
dataset = dataset.repeat(repeat).prefetch(prefetch_batch)




# 建立一个迭代器
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()


# 建立一个会话Session
with tf.Session() as sess:
    # 初始化迭代器
    sess.run(iterator.initializer)
    for i in range(10):
        data = sess.run(batch_data)
        print(data)
        print("="*20)

结果如下:

[b'002411.jpg' b'001866.jpg' b'003657.jpg' b'000849.jpg' b'002705.jpg'

b'002485.jpg' b'000120.jpg' b'002057.jpg' b'003620.jpg' b'003092.jpg'

b'003111.jpg' b'000557.jpg' b'003030.jpg' b'001831.jpg' b'001967.jpg'

b'000258.jpg' b'002366.jpg' b'004102.jpg' b'000067.jpg' b'002444.jpg'

b'003847.jpg' b'003876.jpg' b'000516.jpg' b'002107.jpg' b'003941.jpg'

b'004006.jpg' b'000632.jpg' b'000080.jpg' b'002286.jpg' b'003046.jpg'

b'000785.jpg' b'001122.jpg']

====================

[b'001681.jpg' b'001067.jpg' b'002432.jpg' b'002173.jpg' b'001478.jpg'

b'000610.jpg' b'001715.jpg' b'002695.jpg' b'004003.jpg' b'004100.jpg'

b'002240.jpg' b'000286.jpg' b'003298.jpg' b'000760.jpg' b'003712.jpg'

b'003076.jpg' b'003598.jpg' b'000423.jpg' b'003211.jpg' b'002405.jpg'

b'001274.jpg' b'003872.jpg' b'004079.jpg' b'000486.jpg' b'004012.jpg'

b'003247.jpg' b'001156.jpg' b'004073.jpg' b'002359.jpg' b'000636.jpg'

b'000349.jpg' b'001392.jpg']

====================

[b'001704.jpg' b'001051.jpg' b'002887.jpg' b'003227.jpg' b'000357.jpg'

b'003706.jpg' b'003297.jpg' b'004016.jpg' b'002112.jpg' b'002975.jpg'

b'004077.jpg' b'002272.jpg' b'001991.jpg' b'000694.jpg' b'001515.jpg'

b'000242.jpg' b'002169.jpg' b'003926.jpg' b'001462.jpg' b'002646.jpg'

b'003214.jpg' b'000487.jpg' b'000326.jpg' b'001344.jpg' b'001069.jpg'

b'003025.jpg' b'002724.jpg' b'002502.jpg' b'002479.jpg' b'004098.jpg'

b'001749.jpg' b'003203.jpg']

====================

[b'003235.jpg' b'003145.jpg' b'000356.jpg' b'003175.jpg' b'001426.jpg'

b'003209.jpg' b'004105.jpg' b'002073.jpg' b'003118.jpg' b'003629.jpg'

b'001634.jpg' b'003120.jpg' b'000098.jpg' b'001096.jpg' b'001607.jpg'

b'003158.jpg' b'004115.jpg' b'000084.jpg' b'003362.jpg' b'003666.jpg'

b'001573.jpg' b'002369.jpg' b'002097.jpg' b'003621.jpg' b'003484.jpg'

b'003809.jpg' b'001107.jpg' b'001207.jpg' b'003556.jpg' b'003763.jpg'

b'003594.jpg' b'001101.jpg']

====================

[b'000073.jpg' b'003798.jpg' b'002839.jpg' b'002614.jpg' b'002921.jpg'

b'002453.jpg' b'003261.jpg' b'002648.jpg' b'002605.jpg' b'003388.jpg'

b'003010.jpg' b'000752.jpg' b'003783.jpg' b'001673.jpg' b'002732.jpg'

b'002936.jpg' b'001997.jpg' b'003518.jpg' b'001005.jpg' b'002789.jpg'

b'001082.jpg' b'003087.jpg' b'003873.jpg' b'003871.jpg' b'001441.jpg'

b'003494.jpg' b'000135.jpg' b'001564.jpg' b'000410.jpg' b'002700.jpg'

b'001258.jpg' b'003723.jpg']

====================

[b'000948.jpg' b'003301.jpg' b'003280.jpg' b'001173.jpg' b'002086.jpg'

b'001553.jpg' b'001125.jpg' b'003796.jpg' b'002469.jpg' b'000866.jpg'

b'003491.jpg' b'003708.jpg' b'004152.jpg' b'001616.jpg' b'003965.jpg'

b'002069.jpg' b'002966.jpg' b'000739.jpg' b'001433.jpg' b'000419.jpg'

b'001955.jpg' b'003578.jpg' b'003493.jpg' b'000992.jpg' b'001333.jpg'

b'004042.jpg' b'003442.jpg' b'001623.jpg' b'003615.jpg' b'004140.jpg'

b'003635.jpg' b'000619.jpg']

====================

[b'002193.jpg' b'002691.jpg' b'000456.jpg' b'002500.jpg' b'001423.jpg'

b'003624.jpg' b'002149.jpg' b'000743.jpg' b'001570.jpg' b'002141.jpg'

b'002891.jpg' b'000467.jpg' b'002985.jpg' b'003384.jpg' b'003971.jpg'

b'003143.jpg' b'001541.jpg' b'003032.jpg' b'002317.jpg' b'003951.jpg'

b'001980.jpg' b'000183.jpg' b'002111.jpg' b'001115.jpg' b'000163.jpg'

b'000381.jpg' b'004301.jpg' b'001529.jpg' b'002506.jpg' b'003976.jpg'

b'003886.jpg' b'002414.jpg']

====================

[b'001126.jpg' b'000007.jpg' b'002410.jpg' b'002568.jpg' b'003724.jpg'

b'002947.jpg' b'003988.jpg' b'004004.jpg' b'002682.jpg' b'003284.jpg'

b'000003.jpg' b'001234.jpg' b'001080.jpg' b'002395.jpg' b'000085.jpg'

b'002064.jpg' b'000646.jpg' b'003652.jpg' b'004264.jpg' b'000577.jpg'

b'004320.jpg' b'003726.jpg' b'003859.jpg' b'001369.jpg' b'001056.jpg'

b'003422.jpg' b'003193.jpg' b'001178.jpg' b'000918.jpg' b'000509.jpg'

b'000296.jpg' b'003273.jpg']

====================

[b'001889.jpg' b'003185.jpg' b'000029.jpg' b'002218.jpg' b'001762.jpg'

b'003392.jpg' b'002634.jpg' b'001382.jpg' b'001100.jpg' b'000779.jpg'

b'000544.jpg' b'003537.jpg' b'002630.jpg' b'004138.jpg' b'000539.jpg'

b'002091.jpg' b'000378.jpg' b'002754.jpg' b'002377.jpg' b'002861.jpg'

b'002858.jpg' b'003162.jpg' b'002898.jpg' b'004361.jpg' b'003058.jpg'

b'001686.jpg' b'000629.jpg' b'002349.jpg' b'001722.jpg' b'002675.jpg'

b'002903.jpg' b'001424.jpg']

====================

[b'000194.jpg' b'001345.jpg' b'003553.jpg' b'003031.jpg' b'001821.jpg'

b'003232.jpg' b'000852.jpg' b'003112.jpg' b'000841.jpg' b'002522.jpg'

b'004280.jpg' b'000538.jpg' b'000830.jpg' b'003934.jpg' b'003596.jpg'

b'002004.jpg' b'000794.jpg' b'002015.jpg' b'002620.jpg' b'001974.jpg'

b'003632.jpg' b'002461.jpg' b'003142.jpg' b'000532.jpg' b'001941.jpg'

b'002232.jpg' b'000924.jpg' b'003882.jpg' b'000489.jpg' b'002766.jpg'

b'001795.jpg' b'003866.jpg']

====================

Process finished with exit code 0

总结一下,使用tf.data.Dataset的主要步骤有:

1. 使用tf.data.Dataset.from_tensor_slices从输入的tensor'中获取数据,这个tensor可以是图片的路径组成的list;

2. 使用dataset.filter选择是否过滤数据

3. 使用dataset.map解析数据,如果读入的只是路径,那么使用这个函数可以将路径对应的图片数据读进来;

4. 使用dataset.shuffle打乱数据

5. 使用dataset.batch生成batches

6. 使用dataset.repeat重复数据集

7. 使用dataset.prefetch设置在数据请求前预加载的batch数量

8. 使用iterator = dataset.make_initializable_iterator(), batch_data = iterator.get_next()建立迭代器

9. 在session中初始化迭代器sess.run(iterator.initializer),并通过迭代器的get_next()获取一个batch的数据

 

2.tfrecord

2.1 使用tfrecord的原因

    正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。

    TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

2.2 tfrecord的写入

# 声明一个TFRecordWriter,才能将信息写入TFRecord文件

# 其中output表示为存储的路径,如“output.tfrecord”

writer = tf.python_io.TFRecordWriter(output)

 

# 读取图片并进行解码, input是图片路径

image = Image.open("image.jpg")

shape = image.shape

# 将图片转换成 string。

image_data = image.tostring()

name = bytes("cat", encoding='utf8')

 

# 创建Example对象,并将Feature一一填充进去

example = tf.train.Example(features=tf.train.Features(feature={

   'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),

   'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),

   'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))

    }

   ))

       

# 将 example 序列化成 string 类型,然后写入

writer.write(example.SerializeToString())

 

# 全部写完之后,关闭writer

writer.close()

总结一下,tfrecord主要分为4步:

  1. 声明TFRecordWriter;

  2. 创建Example对象,并将Feature存入;

  3. 将Example序列化成string类型写入;

  4. 关闭writer

2.3 tfrecord的读取

# 定义一个reader

reader = tf.TFRecordReader()

# 读取tfrecord文件,得到一个filename_queue【中括号必须保存下来】

filename_queue=tf.train.string_input_producer(['titanic_train.tfrecords'])

# 返回文件名和文件

_,serialized_example=reader.read(filename_queue)

# 上面的serialized_example是无法直接查看的,需要去按照特征进行解析

features = tf.parse_single_example(serialized_example,features={

    'imgae': tf.FixedLenFeature([], tf.string)

    'label': tf.FixedLenFeature([], tf.string)

})

 

# 每次将数据包装成一个batch,capacity为队列能够容纳的最大元素个数

image, label = tf.train.shuffle_batch([features['image'], features['label']], batch_size=16, capacity=500)

 

with tf.Session() as sess:

    tf.global_variables_initializer().run()

    # 创建 Coordinator 负责实现数据输入线程的同步

    coord = tf.train.Coordinator()

    # 启动队列

    threads=tf.train.start_queue_runners(sess=sess, coord)

    # 喂数据实现训练

    img, lab = sess.run([image, label])

 

    # 线程同步

    coord.request_stop()

    coord.join(threads=threads)

 

总结一下,tfrecord读取数据的步骤主要有:

  1. 使用tf.TFRecordReader()定义reader

  2. 使用tf.train.string_input_producer读取tfrecord文件

  3. 使用reader.read读取第二部返回的文件名队列

  4. 使用tf.parse_single_example解析文件名队列

  5. 使用tf.train.shuffle_batch将数据打乱并包装成一个batch

  6. 在session通过tf.train.Coordinator()实现线程同步

  7. threads=tf.train.start_queue_runners(sess=sess, coord)启动队列

  8. 通过coord.request_stop()和coord.join(threads=threads)实现线程同步

3.两种方式的区别

       tfrecord需要提前将数据存成tfrecord文件,这样可以减少每次打开文件的时间消耗,针对于大规模数据集训练模型上有帮助。但是问题在于这种方式比较死板,如果有新的数据集,就需要继续生成tfrecord文件。

 

     而tf.data.Dataset的方式就比较灵活了,采用pipeline的方式,在GPU训练数据时,CPU准备数据,不需要提前生成其他文件。

 

     总之,这两种方式都可以处理大规模数据集的训练,但个人觉得tf.data.Dataset要好用一些。

 

参考资料:

TensorFlow之tfrecords文件详细教程(https://blog.csdn.net/qq_27825451/article/details/83301811

【Tensorflow】你可能无法回避的 TFRecord 文件格式详细讲解(https://blog.csdn.net/briblue/article/details/80789608

TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制(https://blog.csdn.net/guyuealian/article/details/85106012

 

 

 

 

 

 

 

 

 

 

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值