详细教程:https://www.tensorflow.org/programmers_guide/datasets
通过 tf.data API,您可以根据简单的可重用片段构建复杂的输入管道。例如,
- 图片模型的管道可能会汇聚分布式文件系统中的文件中的数据、对每个图片应用随机扰动,并将随机选择的图片合并成用于训练的批次。
- 文本模型的管道可能包括从原始文本数据中提取符号、根据对照表将其转换为嵌入标识符,以及将不同长度的序列组合成批次数据。
使用 tf.data API 可以轻松__处理大量数据、不同的数据格式__以及__复杂的转换__。
#1. 定义来源
要启动输入管道,您必须定义来源。例如,
- 要通过内存中的某些张量构建 Dataset,您可以使用 tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()
- 如果您的输入数据以推荐的 TFRecord 格式存储在磁盘上,那么您可以构建 tf.data.TFRecordDataset。
一个数据集包含多个元素,每个元素的结构都相同。一个元素包含一个或多个 tf.Tensor 对象,这些对象称为组件。
每个组件都有一个 tf.DType,表示张量中元素的类型;以及一个 tf.TensorShape,表示每个元素(可能部分指定)的静态形状。
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
#通过 Dataset.output_types 和 Dataset.output_shapes 属性
#检查数据集元素各个组件的推理类型和形状
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
#为元素的每个组件命名通常会带来便利性
dataset = tf.data.Dataset.from_tensor_slices(
{
"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
#2. 元素转换
有了 Dataset 对象以后,您就可以通过链接 tf.data.Dataset 对象上的方法调用将其转换为新的 Dataset。例如,
- 应用单元素转换,例如 Dataset.map()(为每个元素应用一个函数)
- 应用多元素转换(例如 Dataset.batch())
Dataset 转换支持任何结构的数据集。在使用 Dataset.map()、Dataset.flat_map() 和 Dataset.filter() 转换时(这些转换会对每个元素应用一个函数),元素结构决定了函数的参数:
dataset1 = dataset1.map(lambda x: ...)
dataset2 = dataset2.flat_map(lambda x, y: ...)
# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)
#3. 构建迭代器对象
消耗 Dataset 中值的最常见方法是构建迭代器对象。通过此对象,可以一次访问数据集中的一个元素(例如通过调用 Dataset.make_one_shot_iterator())。tf.data.Iterator 提供了两个指令:
- Iterator.initializer,您可以通过此指令(重新)初始化迭代器的状态;
- Iterator.get_next(),此指令返回对应于有符号下一个元素的 tf.Tensor 对象。
##3.1 创建单次迭代器
单次迭代器是最简单的迭代器形式,仅支持对数据集进行__一次迭代__,不需要显式初始化。单次迭代器可以处理基于队列的现有输入管道支持的几乎所有情况,但它们__不支持参数化__。以 Dataset.range() 为例:
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
##3.2 创建可初始化迭代器
需要先运行显式 iterator.initializer 指令,才能使用可初始化迭代器。虽然有些不便,但它允许您使用一个或多个 tf.placeholder() 张量(可在初始化迭代器时馈送)参数化数据集的定义。
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={
max_value: 10})
for i in range(