目录
1.tf.data.Dataset
# 从tensor中获取数据
dataset = tf.data.Dataset.from_tensor_slices(img_paths)
# 可选项,从数据集中过滤数据
dataset = dataset.filter(filter)
# 数据解析,原来可能是路径,需要变成真实的图片数据
# 其中num_parallel_calls表示并行操作的线程数量,一般设置为CPU核心数量为最好
dataset = dataset.map(map_func, num_parallel_calls=num_threads)
# 打乱数据,这里有个buffer_size,表示每次从这个buffer_size个数据中随机一个位置,与buffer_size外的数据进行交换
dataset = dataset.shuffle(buffer_size)
# 把数据集组装成batchs
if drop_remainder:
# 将dataset切分成n个batch_size,并且决定是否丢掉最后一个不满足一个batch的数据
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
else:
dataset = dataset.batch(batch_size)
# repeat表示重复的次数,-1表示重复无限次,这样就永远不会报outOfRange这种错
"""
tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。
具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。
prefech(n) 一般作为最后一个 transformation,其中 n 为 batch_size。
因为数据已经分成了好几个batch,那么这句话其实就是在数据被请求之前,预加载2个batch的数据
"""
dataset = dataset.repeat(repeat).prefetch(prefetch_batch)
# 建立一个迭代器
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
# 建立一个会话Session
with tf.Session as sess:
# 初始化迭代器
sess.run(iterator.initializer)
data = sess.run(batch_data)
实测代码:
import os
import numpy as np
import tensorflow as tf
from tflib.utils import session
import random
img_paths = "E:\\python_project\\DeeCamp\\data\\list_attr_celeba.txt"
buffer_size = 4096
drop_remainder = True
batch_size = 32
repeat = -1
prefetch_batch = 2
names = np.loadtxt(img_paths, skiprows=2, usecols=[0], dtype=np.str)
print("start")
# 从tensor中获取数据
dataset = tf.data.Dataset.from_tensor_slices(names)
print("read files over")
# 可选项,从数据集中过滤数据
# dataset = dataset.filter()
# 数据解析,原来可能是路径,需要变成真实的图片数据
# 其中num_parallel_calls表示并行操作的线程数量,一般设置为CPU核心数量为最好
# dataset = dataset.map(map_func, num_parallel_calls=num_threads)
# 打乱数据,这里有个buffer_size,表示每次从这个buffer_size个数据中随机一个位置,与buffer_size外的数据进行交换
dataset = dataset.shuffle(buffer_size)
# 把数据集组装成batchs
if drop_remainder:
# 将dataset切分成n个batch_size,并且决定是否丢掉最后一个不满足一个batch的数据
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
else:
dataset = dataset.batch(batch_size)
# repeat表示重复的次数,-1表示重复无限次,这样就永远不会报outOfRange这种错
"""
tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。
具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。
prefech(n) 一般作为最后一个 transformation,其中 n 为 batch_size。
因为数据已经分成了好几个batch,那么这句话其实就是在数据被请求之前,预加载2个batch的数据
"""
dataset = dataset.repeat(repeat).prefetch(prefetch_batch)
# 建立一个迭代器
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
# 建立一个会话Session
with tf.Session() as sess:
# 初始化迭代器
sess.run(iterator.initializer)
for i in range(10):
data = sess.run(batch_data)
print(data)
print("="*20)
结果如下:
[b'002411.jpg' b'001866.jpg' b'003657.jpg' b'000849.jpg' b'002705.jpg'
b'002485.jpg' b'000120.jpg' b'002057.jpg' b'003620.jpg' b'003092.jpg'
b'003111.jpg' b'000557.jpg' b'003030.jpg' b'001831.jpg' b'001967.jpg'
b'000258.jpg' b'002366.jpg' b'004102.jpg' b'000067.jpg' b'002444.jpg'
b'003847.jpg' b'003876.jpg' b'000516.jpg' b'002107.jpg' b'003941.jpg'
b'004006.jpg' b'000632.jpg' b'000080.jpg' b'002286.jpg' b'003046.jpg'
b'000785.jpg' b'001122.jpg']
====================
[b'001681.jpg' b'001067.jpg' b'002432.jpg' b'002173.jpg' b'001478.jpg'
b'000610.jpg' b'001715.jpg' b'002695.jpg' b'004003.jpg' b'004100.jpg'
b'002240.jpg' b'000286.jpg' b'003298.jpg' b'000760.jpg' b'003712.jpg'
b'003076.jpg' b'003598.jpg' b'000423.jpg' b'003211.jpg' b'002405.jpg'
b'001274.jpg' b'003872.jpg' b'004079.jpg' b'000486.jpg' b'004012.jpg'
b'003247.jpg' b'001156.jpg' b'004073.jpg' b'002359.jpg' b'000636.jpg'
b'000349.jpg' b'001392.jpg']
====================
[b'001704.jpg' b'001051.jpg' b'002887.jpg' b'003227.jpg' b'000357.jpg'
b'003706.jpg' b'003297.jpg' b'004016.jpg' b'002112.jpg' b'002975.jpg'
b'004077.jpg' b'002272.jpg' b'001991.jpg' b'000694.jpg' b'001515.jpg'
b'000242.jpg' b'002169.jpg' b'003926.jpg' b'001462.jpg' b'002646.jpg'
b'003214.jpg' b'000487.jpg' b'000326.jpg' b'001344.jpg' b'001069.jpg'
b'003025.jpg' b'002724.jpg' b'002502.jpg' b'002479.jpg' b'004098.jpg'
b'001749.jpg' b'003203.jpg']
====================
[b'003235.jpg' b'003145.jpg' b'000356.jpg' b'003175.jpg' b'001426.jpg'
b'003209.jpg' b'004105.jpg' b'002073.jpg' b'003118.jpg' b'003629.jpg'
b'001634.jpg' b'003120.jpg' b'000098.jpg' b'001096.jpg' b'001607.jpg'
b'003158.jpg' b'004115.jpg' b'000084.jpg' b'003362.jpg' b'003666.jpg'
b'001573.jpg' b'002369.jpg' b'002097.jpg' b'003621.jpg' b'003484.jpg'
b'003809.jpg' b'001107.jpg' b'001207.jpg' b'003556.jpg' b'003763.jpg'
b'003594.jpg' b'001101.jpg']
====================
[b'000073.jpg' b'003798.jpg' b'002839.jpg' b'002614.jpg' b'002921.jpg'
b'002453.jpg' b'003261.jpg' b'002648.jpg' b'002605.jpg' b'003388.jpg'
b'003010.jpg' b'000752.jpg' b'003783.jpg' b'001673.jpg' b'002732.jpg'
b'002936.jpg' b'001997.jpg' b'003518.jpg' b'001005.jpg' b'002789.jpg'
b'001082.jpg' b'003087.jpg' b'003873.jpg' b'003871.jpg' b'001441.jpg'
b'003494.jpg' b'000135.jpg' b'001564.jpg' b'000410.jpg' b'002700.jpg'
b'001258.jpg' b'003723.jpg']
====================
[b'000948.jpg' b'003301.jpg' b'003280.jpg' b'001173.jpg' b'002086.jpg'
b'001553.jpg' b'001125.jpg' b'003796.jpg' b'002469.jpg' b'000866.jpg'
b'003491.jpg' b'003708.jpg' b'004152.jpg' b'001616.jpg' b'003965.jpg'
b'002069.jpg' b'002966.jpg' b'000739.jpg' b'001433.jpg' b'000419.jpg'
b'001955.jpg' b'003578.jpg' b'003493.jpg' b'000992.jpg' b'001333.jpg'
b'004042.jpg' b'003442.jpg' b'001623.jpg' b'003615.jpg' b'004140.jpg'
b'003635.jpg' b'000619.jpg']
====================
[b'002193.jpg' b'002691.jpg' b'000456.jpg' b'002500.jpg' b'001423.jpg'
b'003624.jpg' b'002149.jpg' b'000743.jpg' b'001570.jpg' b'002141.jpg'
b'002891.jpg' b'000467.jpg' b'002985.jpg' b'003384.jpg' b'003971.jpg'
b'003143.jpg' b'001541.jpg' b'003032.jpg' b'002317.jpg' b'003951.jpg'
b'001980.jpg' b'000183.jpg' b'002111.jpg' b'001115.jpg' b'000163.jpg'
b'000381.jpg' b'004301.jpg' b'001529.jpg' b'002506.jpg' b'003976.jpg'
b'003886.jpg' b'002414.jpg']
====================
[b'001126.jpg' b'000007.jpg' b'002410.jpg' b'002568.jpg' b'003724.jpg'
b'002947.jpg' b'003988.jpg' b'004004.jpg' b'002682.jpg' b'003284.jpg'
b'000003.jpg' b'001234.jpg' b'001080.jpg' b'002395.jpg' b'000085.jpg'
b'002064.jpg' b'000646.jpg' b'003652.jpg' b'004264.jpg' b'000577.jpg'
b'004320.jpg' b'003726.jpg' b'003859.jpg' b'001369.jpg' b'001056.jpg'
b'003422.jpg' b'003193.jpg' b'001178.jpg' b'000918.jpg' b'000509.jpg'
b'000296.jpg' b'003273.jpg']
====================
[b'001889.jpg' b'003185.jpg' b'000029.jpg' b'002218.jpg' b'001762.jpg'
b'003392.jpg' b'002634.jpg' b'001382.jpg' b'001100.jpg' b'000779.jpg'
b'000544.jpg' b'003537.jpg' b'002630.jpg' b'004138.jpg' b'000539.jpg'
b'002091.jpg' b'000378.jpg' b'002754.jpg' b'002377.jpg' b'002861.jpg'
b'002858.jpg' b'003162.jpg' b'002898.jpg' b'004361.jpg' b'003058.jpg'
b'001686.jpg' b'000629.jpg' b'002349.jpg' b'001722.jpg' b'002675.jpg'
b'002903.jpg' b'001424.jpg']
====================
[b'000194.jpg' b'001345.jpg' b'003553.jpg' b'003031.jpg' b'001821.jpg'
b'003232.jpg' b'000852.jpg' b'003112.jpg' b'000841.jpg' b'002522.jpg'
b'004280.jpg' b'000538.jpg' b'000830.jpg' b'003934.jpg' b'003596.jpg'
b'002004.jpg' b'000794.jpg' b'002015.jpg' b'002620.jpg' b'001974.jpg'
b'003632.jpg' b'002461.jpg' b'003142.jpg' b'000532.jpg' b'001941.jpg'
b'002232.jpg' b'000924.jpg' b'003882.jpg' b'000489.jpg' b'002766.jpg'
b'001795.jpg' b'003866.jpg']
====================
Process finished with exit code 0
总结一下,使用tf.data.Dataset的主要步骤有:
1. 使用tf.data.Dataset.from_tensor_slices从输入的tensor'中获取数据,这个tensor可以是图片的路径组成的list;
2. 使用dataset.filter选择是否过滤数据;
3. 使用dataset.map解析数据,如果读入的只是路径,那么使用这个函数可以将路径对应的图片数据读进来;
4. 使用dataset.shuffle打乱数据;
5. 使用dataset.batch生成batches;
6. 使用dataset.repeat重复数据集;
7. 使用dataset.prefetch设置在数据请求前预加载的batch数量;
8. 使用iterator = dataset.make_initializable_iterator(), batch_data = iterator.get_next()建立迭代器
9. 在session中初始化迭代器sess.run(iterator.initializer),并通过迭代器的get_next()获取一个batch的数据。
2.tfrecord
2.1 使用tfrecord的原因
正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。
TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
2.2 tfrecord的写入
# 声明一个TFRecordWriter,才能将信息写入TFRecord文件
# 其中output表示为存储的路径,如“output.tfrecord”
writer = tf.python_io.TFRecordWriter(output)
# 读取图片并进行解码, input是图片路径
image = Image.open("image.jpg")
shape = image.shape
# 将图片转换成 string。
image_data = image.tostring()
name = bytes("cat", encoding='utf8')
# 创建Example对象,并将Feature一一填充进去
example = tf.train.Example(features=tf.train.Features(feature={
'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}
))
# 将 example 序列化成 string 类型,然后写入
writer.write(example.SerializeToString())
# 全部写完之后,关闭writer
writer.close()
总结一下,tfrecord主要分为4步:
-
声明TFRecordWriter;
-
创建Example对象,并将Feature存入;
-
将Example序列化成string类型写入;
-
关闭writer
2.3 tfrecord的读取
# 定义一个reader
reader = tf.TFRecordReader()
# 读取tfrecord文件,得到一个filename_queue【中括号必须保存下来】
filename_queue=tf.train.string_input_producer(['titanic_train.tfrecords'])
# 返回文件名和文件
_,serialized_example=reader.read(filename_queue)
# 上面的serialized_example是无法直接查看的,需要去按照特征进行解析
features = tf.parse_single_example(serialized_example,features={
'imgae': tf.FixedLenFeature([], tf.string)
'label': tf.FixedLenFeature([], tf.string)
})
# 每次将数据包装成一个batch,capacity为队列能够容纳的最大元素个数
image, label = tf.train.shuffle_batch([features['image'], features['label']], batch_size=16, capacity=500)
with tf.Session() as sess:
tf.global_variables_initializer().run()
# 创建 Coordinator, 负责实现数据输入线程的同步
coord = tf.train.Coordinator()
# 启动队列
threads=tf.train.start_queue_runners(sess=sess, coord)
# 喂数据实现训练
img, lab = sess.run([image, label])
# 线程同步
coord.request_stop()
coord.join(threads=threads)
总结一下,tfrecord读取数据的步骤主要有:
-
使用tf.TFRecordReader()定义reader
-
使用tf.train.string_input_producer读取tfrecord文件
-
使用reader.read读取第二部返回的文件名队列
-
使用tf.parse_single_example解析文件名队列
-
使用tf.train.shuffle_batch将数据打乱并包装成一个batch
-
在session通过tf.train.Coordinator()实现线程同步
-
threads=tf.train.start_queue_runners(sess=sess, coord)启动队列
-
通过coord.request_stop()和coord.join(threads=threads)实现线程同步
3.两种方式的区别
tfrecord需要提前将数据存成tfrecord文件,这样可以减少每次打开文件的时间消耗,针对于大规模数据集训练模型上有帮助。但是问题在于这种方式比较死板,如果有新的数据集,就需要继续生成tfrecord文件。
而tf.data.Dataset的方式就比较灵活了,采用pipeline的方式,在GPU训练数据时,CPU准备数据,不需要提前生成其他文件。
总之,这两种方式都可以处理大规模数据集的训练,但个人觉得tf.data.Dataset要好用一些。
参考资料:
TensorFlow之tfrecords文件详细教程(https://blog.csdn.net/qq_27825451/article/details/83301811)
【Tensorflow】你可能无法回避的 TFRecord 文件格式详细讲解(https://blog.csdn.net/briblue/article/details/80789608)
TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制(https://blog.csdn.net/guyuealian/article/details/85106012)