TensorFlow提供了一种内置的API—ataset,使得我们可以很容易地就利用输入管道的方式输入数据。可以像队列读取数据那样,生产batch、数据增强等等。
tf.data.Dataset
可以表示为一些元素的序列,该元素序列可以是列表、元组甚至是字典。比如对于图像通道,元素可以是单独的数据样本,也可以是成对的(样本+label),这里提供了两种不同的创建dataset的方式:
Dataset.from_tensor_slices():从数据中返回一个切片,也就是单个数据信息
Dataset.batch():对数据应用变换,使其返回一个batch
tf.data.Iterator是从数据集中提取元素的主要方法,通过Iterator.get_next()产生Dataset下一个元素。最简单的迭代器是"one-shot iterator",它可以对Dataset迭代一次;对于复杂的情况,Iterator.initializer可以让你重新启动和参数化一个迭代器,这样就可以在一个程序中多次加载训练集和验证集。
# 创建一个Dataset,(此例是from tensor创建)
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) # tensor
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32))) # 元组
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)}) # 字典
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.o