tf.data API带来了TensorFlow的两种新抽象:
tf.data.Dataset : 表示元素的序列,其中每个元素包含了一个或多个Tensor对象。例如,一个图像数据管道中,一个元素可能是一个具有一对张量表示其图像数据和标签的训练样本 。有两个不同的方法创建dataset :
创造source (例如Dataset.from_tensor_slices()) 从一个或多个tf.Tensor对象中构建dataset
使用transformation(例如Dataset.batch()) 从一个或多个tf.data.Dataset对象中构建dataset
tf.data.Iterator : 是从dataset中提取元素的主要方式。该操作通过 Iterator.get_next() yield出 Dataset的下一个元素,一般作为输入管道和模型之间的接口。最简单的迭代器是 “one-shot iterator” ,用来迭代一次某个特定的Dataset。更复杂的使用中,Iterator.initializer操作可以用不同的dataset重新初始并且参数化一个迭代器,可以以此在一个程序中训练和验证数据多次。
基础机制
启动一个输入管道,需要首先定义source 。 为了从内存的张量中构建Dataset,可以使用 tf.data.Dataset.from_tensors() 或者 tf.data.Dataset.from_tensor_slices()。
拥有了Dataset对象以后,可以将它们转化为新的Dataset。例如Dataset.map() , Dataset.batch() 等
最常用于使用Dataset的方式是使用一个迭代器(例如 Dataset.make_one_shot_iterator())。
tf.data.Iterator提供两种操作:
Iterator.initializer (重新)初始化迭代器
Iterator.get_next() 返回下一个元素的tf.Tensor对象
Dataset构建
dataset包括了含有相同结构的元素。一个元素包含一个或多个tf.Tensor对象,称作组件。每个组件有一个tf.DType 表示元素的类型;还有一个tf.TensorShape表示元素的形状。可以用Dataset.output_types和Dataset.output_shapes来检查:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
给元素中的组件命名也很方便 :
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
以及转变操作 :
dataset1 = dataset1.map(lambda x: ...)
dataset2 = dataset2.flat_map(lambda x, y: ...)
# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)
创建迭代器
有了表示输入数据的Dataset后,下一步是创建可以得到其中元素的迭代器,随着复杂程度的提高,tf.data API 提供一下迭代器:
- one-shot
- initializable
- reinitializable
- feedable
one-shot是最简单的迭代器,不需要明确的初始化,迭代一次,但不支持参数化:
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(100):
value = sess.run(next_element)
assert i == value
print(value)
initializable迭代器需要在使用前进行iterator.initializer的操作,虽然不方便,但支持参数化。可以使用一个或多个tf.placeholder()在初始化迭代器时占位:
max_value = tf.placeholder(tf.int64,shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elemnets
with tf.Session() as sess:
sess.run(iterator.initializer , feed_dict={max_value:10})
for i in range(10):
value = sess.run(next_element)
assert i == value
with tf.Session() as sess:
sess.run(iterator.initializer , feed_dict = {max_value:100})
for i in range(100):
value = sess.run(next_element)
assert i == value
reinitializable迭代器可以从多个不同的Dataset对象中初始化。例如,一个使用随机干扰以提升泛化能力的训练输入管道,和一个评估未修改数据上的预测的验证输入管道。这些管道一般使用不同的Dataset对象,但是这些对象有相同的结构:
"""Define training and validation datasets with the same structure."""
training_dataset = tf.data.Dataset.range(100).map(lambda x: x + tf.random_uniform([],-10,10,tf.int64))
validation_dataset = tf.data.Dataset.range(50)
"""
A reinititializable iterator is defined by its structure.
we could use the output_types and output_shapes properties of either training_dataset or
validation_dataset here,because they are compatible
此迭代器构造方法可用于创建一个迭代器,该迭代器可用于许多不同的数据集。
返回的迭代器未绑定到特定的数据集没有初始化,需要make_initializer()
"""
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
"""
Run 20 epochs in which the training dataset is traversed,followed by the validation dataset
"""
for _ in range(20):
#initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)
feedable迭代器可以和tf.placeholder一起使用,通过相似的feed_dict机制,选择每一次tf.Session.run中的迭代器。它提供和reinitializable迭代器相同的功能,但是不需要在选择迭代器启动dataset时就初始化迭代器:
training_dataset = tf.data.Dataset.range(100).map(lambdax:
x+tf.random_uniform([],-10,10,tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
"""
可提供的迭代器由句柄占位符及其结构定义。 我们可以在这里使用“ training_dataset”或“
validation_dataset”的“ output_types”和“ output_shapes”属性,因为它们具有相同的结构。
"""
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
"""
您可以将feedable迭代器与各种不同的迭代器一起使用
例如单次迭代和可初始化的迭代器)。
"""
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
# Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
从迭代器中获取数据
使用Iterator.get_next()方法
当迭代器到达dataset尾部时,运行Iterator.get_next()会raise一个tf.errors.OutOfRangeError,这个迭代器就处于不可用状态,必须重新初始化才可以使用。
dataset = tf.data.Dataset.range(5)
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
result = tf.add(next_element , next_element)
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
sess.run(result)
except tf.errors.OutOfRangeError:
print("End of dataset") #==> "End of dataset"
如果每个元素都有嵌套结构,那么 Iterator.get_next() 会返回相同结构的一个或多个 tf.Tensor 对象:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]),
tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
评估next1,next2,next3中的任意一个,都将推动迭代器
读入输入数据
使用Numpy数组
如果所有的输入数据都在内存里,从中创建dataset最简单的方法是将它们转化为tf.Tensor对象,然后用Dataset.from_tensor_slices() :
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
以上代码将feature和label数组以 tf.constant() 操作组合。这个适用于小数据集,但是浪费内存——会将数组内容复制多次,并且可能会遇到tf.GraphDef规定的2GB缓存限制。
作为替代,可以用tf.placeholder() 张量来定义Dataset , 并在迭代器初始化时,把Numpy数组feed进去。
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(feature.dtype , features.shape)
label_placeholder = tf.placeholder(labels.dtype , labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder , labels_placeholder))
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict = {features_placeholder: features,
labels_placeholder: labels})
原文链接:点此处