数据集Dataset
TensorFlow提供一套高层的数据处理框架,将每一个数据来源抽象成一个“数据集”,开发者可以以数据集为基本对象,方便进行batching、shuffle等操作。
数据集读取数据基本步骤:
- 定义数据集构造方法:不同数据来源调用不同构造方法(张量——tf.data.Dataset.from_tensor_slices()、文本文件——tf.data.TextLineDataset()、TFRecord——tf.data.TFRecordDataset())
- 定义遍历器:主要方式有两种,make_one_shot_iterator()、make_initializable_iterator(),第二种更加灵活,对应placeholder使用,后面的代码会详细标明如何搭配使用
- 使用get_next()方法从迭代器中读取数据张量,作为计算图其他部分的输入
具体实例:
import tensorflow as tf
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
iterator = dataset.make_initializable_iterator()
x = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer,
feed_dict={
input_files: ["", ""]
})
# 遍历结束时抛出OutofRange异常以结束程序
while True:
try:
sess.run(x)
except tf.errors.OutOfRangeError:
break
数据集中更高层的操作:
dataset = dataset.shuffle(buffer_size)
随机打乱顺序
buffer_size限制buffer缓冲区中的最少元素个数,缓冲区的大小越大,随机的性能越好,但占用的内存越多。
dataset = dataset.batch(batch_size)
batch_size代表要输出的每个batch的数据条数
若每个数据为image、label两个张量(即iterator.get_next()的一个返回值),其中image:[300, 300],label:[],batch_size=128,则经过batch操作后数据集的每个输出为:[128, 300, 300], [128]