tensorflow学习笔记(3):使用tf.data API导入数据

本文详述了如何使用 TensorFlow 的 tf.data API 构建复杂输入管道,涵盖了从定义数据源、元素转换到批处理和预处理等步骤。通过实例展示了如何处理图像、文本数据,以及如何实现数据的随机重排和多周期训练。
摘要由CSDN通过智能技术生成

详细教程:https://www.tensorflow.org/programmers_guide/datasets
通过 tf.data API,您可以根据简单的可重用片段构建复杂的输入管道。例如,

  • 图片模型的管道可能会汇聚分布式文件系统中的文件中的数据、对每个图片应用随机扰动,并将随机选择的图片合并成用于训练的批次。
  • 文本模型的管道可能包括从原始文本数据中提取符号、根据对照表将其转换为嵌入标识符,以及将不同长度的序列组合成批次数据。

使用 tf.data API 可以轻松__处理大量数据、不同的数据格式__以及__复杂的转换__。

#1. 定义来源
要启动输入管道,您必须定义来源。例如,

  • 要通过内存中的某些张量构建 Dataset,您可以使用 tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()
  • 如果您的输入数据以推荐的 TFRecord 格式存储在磁盘上,那么您可以构建 tf.data.TFRecordDataset。

一个数据集包含多个元素,每个元素的结构都相同。一个元素包含一个或多个 tf.Tensor 对象,这些对象称为组件。
每个组件都有一个 tf.DType,表示张量中元素的类型;以及一个 tf.TensorShape,表示每个元素(可能部分指定)的静态形状。

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
#通过 Dataset.output_types 和 Dataset.output_shapes 属性
#检查数据集元素各个组件的推理类型和形状
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

#为元素的每个组件命名通常会带来便利性
dataset = tf.data.Dataset.from_tensor_slices(
   {
   "a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

#2. 元素转换
有了 Dataset 对象以后,您就可以通过链接 tf.data.Dataset 对象上的方法调用将其转换为新的 Dataset。例如,

  • 应用单元素转换,例如 Dataset.map()(为每个元素应用一个函数)
  • 应用多元素转换(例如 Dataset.batch())

Dataset 转换支持任何结构的数据集。在使用 Dataset.map()、Dataset.flat_map() 和 Dataset.filter() 转换时(这些转换会对每个元素应用一个函数),元素结构决定了函数的参数:

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

#3. 构建迭代器对象
消耗 Dataset 中值的最常见方法是构建迭代器对象。通过此对象,可以一次访问数据集中的一个元素(例如通过调用 Dataset.make_one_shot_iterator())。tf.data.Iterator 提供了两个指令:

  • Iterator.initializer,您可以通过此指令(重新)初始化迭代器的状态;
  • Iterator.get_next(),此指令返回对应于有符号下一个元素的 tf.Tensor 对象。

##3.1 创建单次迭代器
单次迭代器是最简单的迭代器形式,仅支持对数据集进行__一次迭代__,不需要显式初始化。单次迭代器可以处理基于队列的现有输入管道支持的几乎所有情况,但它们__不支持参数化__。以 Dataset.range() 为例:

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

##3.2 创建可初始化迭代器
需要先运行显式 iterator.initializer 指令,才能使用可初始化迭代器。虽然有些不便,但它允许您使用一个或多个 tf.placeholder() 张量(可在初始化迭代器时馈送)参数化数据集的定义。

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={
   max_value: 10})
for i in range(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值