1,tf-data两个新的抽象类
dataset表示一系列元素,其中每个元素包含一个或多个 Tensor 对象
创建来源(例如 Dataset.from_tensor_slices()),以通过一个或多个 tf.Tensor 对象构建数据集。
应用转换(例如 Dataset.batch()),以通过一个或多个 tf.data.Dataset 对象构建数据集
iterator提供了从数据集中提取元素的主要方法。
Iterator.get_next() 返回的操作会在执行时生成 Dataset 的下一个元素,并且此操作通常充当输入管道代码和模型之间的接口。最简单的迭代器是“单次迭代器”,它与特定的 Dataset 相关联,并对其进行一次迭代。要实现更复杂的用途,您可以通过 Iterator.initializer 操作使用不同的数据集重新初始化和参数化迭代器
2,基本机制
2.1,定义来源
要通过内存中的某些张量构建 Dataset,您可以使用 tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()。或者,如果输入数据以推荐的 TFRecord 格式存储在磁盘上,那么您可以构建 tf.data.TFRecordDataset
2.2,有了 Dataset 对象,可以将其转换为新的 Dataset
方法是链接tf.data.Dataset 对象上的方法调用。例如,您可以应用单元素转换,例如 Dataset.map()(为每个元素应用一个函数),也可以应用多元素转换(例如 Dataset.batch())
2.3,消耗 Dataset 中值的最常见方法是构建迭代器对象。
通过此对象,可以一次访问数据集中的一个元素(例如通过调用 Dataset.make_one_shot_iterator())。tf.data.Iterator 提供了两个操作:Iterator.initializer,您可以通过此操作(重新)初始化迭代器的状态;以及 Iterator.get_next(),此操作返回对应于有符号下一个元素的 tf.Tensor 对象
3,数据集结构
一个数据集包含多个元素,每个元素的结构都相同。一个元素包含一个或多个 tf.Tensor 对象,这些对象称为组件。每个组件都有一个 tf.DType,表示张量中元素的类型;以及一个 tf.TensorShape,表示每个元素(可能部分指定)的静态形状。您可以通过 Dataset.output_types 和 Dataset.output_shapes 属性检查数据集元素各个组件的推理类型和形状
dataset1 =tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))print(dataset1.output_types)# ==> "tf.float32"print(dataset1.output_shapes)# ==> "(10,)"dataset2 =tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]),tf.random_uniform([4,100],maxval=100,dtype=tf.int32)))print(dataset2.output_types)# ==> "(tf.float32, tf.int32)"print(dataset2.output_shapes)# ==> "((), (100,))"dataset3 =tf.data.Dataset.zip((dataset1,dataset2))print(dataset3.output_types)# ==> (tf.float32, (tf.float32, tf.int32))print(dataset3.output_shapes)# ==> "(10, ((), (100,)))"
dataset =tf.data.Dataset.from_tensor_slices({"a":tf.random_uniform([4]),"b":tf.random_uniform([4,100],maxval=100,dtype=tf.int32)})print(dataset.output_types)# ==> "{'a': tf.float32, 'b': tf.int32}"print(dataset.output_shapes)# ==> "{'a': (), 'b': (100,)}"
4,Dataset 转换
Dataset 转换支持任何结构的数据集。在使用 Dataset.map()、Dataset.flat_map() 和 Dataset.filter() 转换时(这些转换会对每个元素应用一个函数),元素结构决定了函数的参数.
dataset1 =dataset1.map(lambdax:...)dataset2 =dataset2.flat_map(lambdax,y:...)# Note: Argument destructuring is not available in Python 3.dataset3 =dataset3.filter(lambdax,(y,z):...)
5,创建迭代器
单次,
可初始化,
可重新初始化,以及
可馈送。
单次:
迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。单次迭代器可以处理基于队列的现有输入管道支持的几乎所有情况,但它们不支持参数化
dataset =tf.data.Dataset.range(100)iterator =dataset.make_one_shot_iterator()next_element =iterator.get_next()fori inrange(100):value =sess.run(next_element)asserti ==value
可初始化:
您需要先运行显式 iterator.initializer 操作,然后才能使用可初始化迭代器.它允许您使用一个或多个 tf.placeholder() 张量(可在初始化迭代器时馈送)参数化数据集的定义max_value =tf.placeholder(tf.int64,shape=[])
dataset =tf.data.Dataset.range(max_value)iterator =dataset.make_initializable_iterator()next_element =iterator.get_next()# Initialize an iterator over a dataset with 10 elements.sess.run(iterator.initializer,feed_dict={max_value:10})fori inrange(10):value =sess.run(next_element)asserti ==value
# Initialize the same iterator over a dataset with 100 elements.sess.run(iterator.initializer,feed_dict={max_value:100})fori inrange(100):value =sess.run(next_element)asserti ==value
可重新初始化:
迭代器可以通过多个不同的 Dataset 对象进行初始化.这些对象具有相同的结构(即每个组件具有相同类型和兼容形状)
# Define training and validation datasets with the same structure.training_dataset =tf.data.Dataset.range(100).map(lambdax:x +tf.random_uniform([],-10,10,tf.int64))validation_dataset =tf.data.Dataset.range(50)# A reinitializable iterator is defined by its structure. We could use the# `output_types` and `output_shapes` properties of either `training_dataset`# or `validation_dataset` here, because they are compatible.iterator =tf.data.Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)next_element =iterator.get_next()training_init_op =iterator.make_initializer(training_dataset)validation_init_op =iterator.make_initializer(validation_dataset)# Run 20 epochs in which the training dataset is traversed, followed by the# validation dataset.for_ inrange(20):# Initialize an iterator over the training dataset.sess.run(training_init_op)for_ inrange(100):sess.run(next_element)# Initialize an iterator over the validation dataset.sess.run(validation_init_op)for_ inrange(50):sess.run(next_element)
可馈送
迭代器可以与 tf.placeholder 一起使用,以选择所使用的 Iterator(在每次调用 tf.Session.run 时)(通过熟悉的 feed_dict 机制)。它提供的功能与可重新初始化迭代器的相同,但在迭代器之间切换时不需要从数据集的开头初始化迭代器.tf.data.Iterator.from_string_handle
# Define training and validation datasets with the same structure.training_dataset =tf.data.Dataset.range(100).map(lambdax:x +tf.random_uniform([],-10,10,tf.int64)).repeat()validation_dataset =tf.data.Dataset.range(50)# A feedable iterator is defined by a handle placeholder and its structure. We# could use the `output_types` and `output_shapes` properties of either# `training_dataset` or `validation_dataset` here, because they have# identical structure.handle =tf.placeholder(tf.string,shape=[])iterator =tf.data.Iterator.from_string_handle(handle,training_dataset.output_types,training_dataset.output_shapes)next_element =iterator.get_next()# You can use feedable iterators with a variety of different kinds of iterator# (such as one-shot and initializable iterators).training_iterator =training_dataset.make_one_shot_iterator()validation_iterator =validation_dataset.make_initializable_iterator()# The `Iterator.string_handle()` method returns a tensor that can be evaluated# and used to feed the `handle` placeholder.training_handle =sess.run(training_iterator.string_handle())validation_handle =sess.run(validation_iterator.string_handle())# Loop forever, alternating between training and validation.whileTrue:# Run 200 steps using the training dataset. Note that the training dataset is# infinite, and we resume from where we left off in the previous `while` loop# iteration.for_ inrange(200):sess.run(next_element,feed_dict={handle:training_handle})# Run one pass over the validation dataset.sess.run(validation_iterator.initializer)for_ inrange(50):sess.run(next_element,feed_dict={handle:validation_handle
6,消耗迭代器中的值
Iterator.get_next() 方法返回一个或多个 tf.Tensor 对象,这些对象对应于迭代器有符号的下一个元素。每次评估这些张量时,它们都会获取底层数据集中下一个元素的值。(请注意,与 TensorFlow 中的其他有状态对象一样,调用 Iterator.get_next() 并不会立即使迭代器进入下个状态。您必须在 TensorFlow 表达式中使用此函数返回的 tf.Tensor 对象,并将该表达式的结果传递到 tf.Session.run(),以获取下一个元素并使迭代器进入下个状态。)
如果迭代器到达数据集的末尾,则执行 Iterator.get_next() 操作会产生 tf.errors.OutOfRangeError。在此之后,迭代器将处于不可用状态;如果需要继续使用,则必须对其重新初始化
dataset =tf.data.Dataset.range(5)iterator =dataset.make_initializable_iterator()next_element =iterator.get_next()# Typically `result` will be the output of a model, or an optimizer's# training operation.result =tf.add(next_element,next_element)sess.run(iterator.initializer)print(sess.run(result))# ==> "0"print(sess.run(result))# ==> "2"print(sess.run(result))# ==> "4"print(sess.run(result))# ==> "6"print(sess.run(result))# ==> "8"try:sess.run(result)excepttf.errors.OutOfRangeError:print("End of dataset")# ==> "End of dataset"
sess.run(iterator.initializer)whileTrue:try:sess.run(result)excepttf.errors.OutOfRangeError:break
如果数据集的每个元素都具有嵌套结构,则 Iterator.get_next() 的返回值将是一个或多个 tf.Tensor 对象,这些对象具有相同的嵌套结构:
dataset1 =tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))dataset2 =tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]),tf.random_uniform([4,100])))dataset3 =tf.data.Dataset.zip((dataset1,dataset2))iterator =dataset3.make_initializable_iterator()sess.run(iterator.initializer)next1,(next2,next3)=iterator.get_next()
请注意,next1、next2 和 next3 是由同一个操作/节点(通过 Iterator.get_next() 创建)生成的张量。因此,评估其中任何一个张量都会使所有组件的迭代器进入下个状态。典型的迭代器消耗方会在一个表达式中包含所有组件
7,保存迭代器状态
tf.contrib.data.make_saveable_from_iterator 函数通过迭代器创建一个 SaveableObject,该对象可用于保存和恢复迭代器(实际上是整个输入管道)的当前状态。以这种方式创建的可保存对象可以添加到 tf.train.Saver 变量列表或 tf.GraphKeys.SAVEABLE_OBJECTS 集合中,以便采用与 tf.Variable 相同的方式进行保存和恢复。请参阅保存和恢复,详细了解如何保存和恢复变量。
# Create saveable object from iterator.saveable =tf.contrib.data.make_saveable_from_iterator(iterator)# Save the iterator state by adding it to the saveable objects collection.tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS,saveable)saver =tf.train.Saver()withtf.Session()assess:ifshould_checkpoint:saver.save(path_to_checkpoint)# Restore the iterator state.withtf.Session()assess:saver.restore(sess,path_to_checkpoint)
8,读取输入数据
8.1,消耗 NumPy 数组
# Load the training data into two NumPy arrays, for example using `np.load()`.withnp.load("/var/data/training_data.npy")asdata:features =data["features"]labels =data["labels"]# Assume that each row of `features` corresponds to the same row as `labels`.assertfeatures.shape[0]==labels.shape[0]dataset =tf.data.Dataset.from_tensor_slices((features,labels))
请注意,上面的代码段会将 features 和 labels 数组作为 tf.constant() 指令嵌入在 TensorFlow 图中。这样非常适合小型数据集,但会浪费内存,因为会多次复制数组的内容,并可能会达到 tf.GraphDef 协议缓冲区的 2GB 上限。
作为替代方案,您可以根据 tf.placeholder() 张量定义 Dataset,并在对数据集初始化 Iterator 时馈送 NumPy 数组。
# Load the training data into two NumPy arrays, for example using `np.load()`.withnp.load("/var/data/training_data.npy")asdata:features =data["features"]labels =data["labels"]# Assume that each row of `features` corresponds to the same row as `labels`.assertfeatures.shape[0]==labels.shape[0]features_placeholder =tf.placeholder(features.dtype,features.shape)labels_placeholder =tf.placeholder(labels.dtype,labels.shape)dataset =tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))# [Other transformations on `dataset`...]dataset =...iterator =dataset.make_initializable_iterator()sess.run(iterator.initializer,feed_dict={features_placeholder:features,labels_placeholder:labels})
8.2,消耗 TFRecord 数据
tf.data API 支持多种文件格式,因此您可以处理那些不适合存储在内存中的大型数据集。例如,TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据。通过 tf.data.TFRecordDataset 类,您可以将一个或多个 TFRecord 文件的内容作为输入管道的一部分进行流式传输
# Creates a dataset that reads all of the examples from two files.filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)
TFRecordDataset 初始化程序的 filenames 参数可以是字符串、字符串列表,也可以是字符串 tf.Tensor。因此,如果您有两组分别用于训练和验证的文件,则可以使用 tf.placeholder(tf.string) 来表示文件名,并使用适当的文件名初始化迭代器:
filenames =tf.placeholder(tf.string,shape=[None])dataset =tf.data.TFRecordDataset(filenames)dataset =dataset.map(...)# Parse the record into tensors.dataset =dataset.repeat()# Repeat the input indefinitely.dataset =dataset.batch(32)iterator =dataset.make_initializable_iterator()# You can feed the initializer with the appropriate filenames for the current# phase of execution, e.g. training vs. validation.# Initialize `iterator` with training data.training_filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]sess.run(iterator.initializer,feed_dict={filenames:training_filenames})# Initialize `iterator` with validation data.validation_filenames =["/var/data/validation1.tfrecord",...]sess.run(iterator.initializer,feed_dict={filenames:validation_filenames})
8.3,消耗文本数据
filenames =["/var/data/file1.txt","/var/data/file2.txt"]dataset =tf.data.TextLineDataset(filenames)
默认情况下,TextLineDataset 会生成每个文件的每一行,这可能是不可取的(例如,如果文件以标题行开头或包含注释)。可以使用 Dataset.skip() 和 Dataset.filter() 转换来移除这些行。为了将这些转换分别应用于每个文件,我们使用 Dataset.flat_map() 为每个文件创建一个嵌套的 Dataset。
filenames =["/var/data/file1.txt","/var/data/file2.txt"]dataset =tf.data.Dataset.from_tensor_slices(filenames)# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,# and then concatenate their contents sequentially into a single "flat" dataset.# * Skip the first line (header row).# * Filter out lines beginning with "#" (comments).dataset =dataset.flat_map(lambdafilename:(tf.data.TextLineDataset(filename).skip(1).filter(lambdaline:tf.not_equal(tf.substr(line,0,1),"#"))))
8.4,消耗 CSV 数据
给定一个或多个文件名以及默认值列表后,CsvDataset 将生成一个元素元组,元素类型对应于为每个 CSV 记录提供的默认元素类型
# Creates a dataset that reads all of the records from two CSV files, each with# eight float columnsfilenames =["/var/data/file1.csv","/var/data/file2.csv"]record_defaults =[tf.float32]*8# Eight required float columnsdataset =tf.contrib.data.CsvDataset(filenames,record_defaults)
# Creates a dataset that reads all of the records from two CSV files, each with# four float columns which may have missing valuesrecord_defaults =[[0.0]]*8dataset =tf.contrib.data.CsvDataset(filenames,record_defaults)
# Creates a dataset that reads all of the records from two CSV files with# headers, extracting float data from columns 2 and 4.record_defaults =[[0.0]]*2# Only provide defaults for the selected columnsdataset =tf.contrib.data.CsvDataset(filenames,record_defaults,header=True,select_cols=[2,4])
9,使用 Dataset.map() 预处理数据
Dataset.map(f) 转换通过将指定函数 f 应用于输入数据集的每个元素来生成新数据集
解析 tf.Example 协议缓冲区消息
# Transforms a scalar string `example_proto` into a pair of a scalar string and# a scalar integer, representing an image and its label, respectively.def_parse_function(example_proto):features ={"image":tf.FixedLenFeature((),tf.string,default_value=""),"label":tf.FixedLenFeature((),tf.int64,default_value=0)}parsed_features =tf.parse_single_example(example_proto,features)returnparsed_features["image"],parsed_features["label"]# Creates a dataset that reads all of the examples from two files, and extracts# the image and label features.filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)dataset =dataset.map(_parse_function)
解码图片数据并调整其大小,use map
# Reads an image from a file, decodes it into a dense tensor, and resizes it# to a fixed shape.def_parse_function(filename,label):image_string =tf.read_file(filename)image_decoded =tf.image.decode_jpeg(image_string)image_resized =tf.image.resize_images(image_decoded,[28,28])returnimage_resized,label
# A vector of filenames.filenames =tf.constant(["/var/data/image1.jpg","/var/data/image2.jpg",...])# `labels[i]` is the label for the image in `filenames[i].labels =tf.constant([0,37,...])dataset =tf.data.Dataset.from_tensor_slices((filenames,labels))dataset =dataset.map(_parse_function)
使用 tf.py_func() 应用任意 Python 逻辑
为了确保性能,我们建议您尽可能使用 TensorFlow 指令预处理数据。不过,在解析输入数据时,调用外部 Python 库有时很有用。为此,请在 Dataset.map() 转换中调用 tf.py_func() 指令
importcv2
# Use a custom OpenCV function to read the image, instead of the standard# TensorFlow `tf.read_file()` operation.def_read_py_function(filename,label):image_decoded =cv2.imread(filename.decode(),cv2.IMREAD_GRAYSCALE)returnimage_decoded,label
# Use standard TensorFlow operations to resize the image to a fixed shape.def_resize_function(image_decoded,label):image_decoded.set_shape([None,None,None])image_resized =tf.image.resize_images(image_decoded,[28,28])returnimage_resized,label
filenames =["/var/data/image1.jpg","/var/data/image2.jpg",...]labels =[0,37,29,1,...]dataset =tf.data.Dataset.from_tensor_slices((filenames,labels))dataset =dataset.map(lambdafilename,label:tuple(tf.py_func(_read_py_function,[filename,label],[tf.uint8,label.dtype])))dataset =dataset.map(_resize_function)
批处理数据集元素
最简单的批处理形式是将数据集中的 n 个连续元素堆叠为一个元素。Dataset.batch() 转换正是这么做的,它与 tf.stack() 运算符具有相同的限制(被应用于元素的每个组件):即对于每个组件 i,所有元素的张量形状都必须完全相同
inc_dataset =tf.data.Dataset.range(100)dec_dataset =tf.data.Dataset.range(0,-100,-1)dataset =tf.data.Dataset.zip((inc_dataset,dec_dataset))batched_dataset =dataset.batch(4)iterator =batched_dataset.make_one_shot_iterator()next_element =iterator.get_next()print(sess.run(next_element))# ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])print(sess.run(next_element))# ==> ([4, 5, 6, 7], [-4, -5, -6, -7])print(sess.run(next_element))# ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
使用填充批处理张量
dataset =tf.data.Dataset.range(100)dataset =dataset.map(lambdax:tf.fill([tf.cast(x,tf.int32)],x))dataset =dataset.padded_batch(4,padded_shapes=[None])iterator =dataset.make_one_shot_iterator()next_element =iterator.get_next()print(sess.run(next_element))# ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]print(sess.run(next_element))# ==> [[4, 4, 4, 4, 0, 0, 0],# [5, 5, 5, 5, 5, 0, 0],# [6, 6, 6, 6, 6, 6, 0],# [7, 7, 7, 7, 7, 7, 7]]
您可以通过 Dataset.padded_batch() 转换为每个组件的每个维度设置不同的填充,并且可以采用可变长度(在上面的示例中用 None 表示)或恒定长度。也可以替换填充值,默认设置为 0
10,训练工作流程
要迭代数据集多个周期,最简单的方法是使用 Dataset.repeat() 转换
filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)dataset =dataset.map(...)dataset =dataset.repeat(10)dataset =dataset.batch(32)
应用不带参数的 Dataset.repeat() 转换将无限次地重复输入。Dataset.repeat() 转换将其参数连接起来,而不会在一个周期结束和下一个周期开始时发出信号。
如果您想在每个周期结束时收到信号,则可以编写在数据集结束时捕获 tf.errors.OutOfRangeError 的训练循环。此时,您可以收集关于该周期的一些统计信息(例如验证错误)
filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)dataset =dataset.map(...)dataset =dataset.batch(32)iterator =dataset.make_initializable_iterator()next_element =iterator.get_next()# Compute for 100 epochs.for_ inrange(100):sess.run(iterator.initializer)whileTrue:try:sess.run(next_element)excepttf.errors.OutOfRangeError:break# [Perform end-of-epoch calculations here.]
随机重排输入数据
Dataset.shuffle() 转换会使用类似于 tf.RandomShuffleQueue 的算法随机重排输入数据集:它会维持一个固定大小的缓冲区,并从该缓冲区统一地随机选择下一个元素
filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)dataset =dataset.map(...)dataset =dataset.shuffle(buffer_size=10000)dataset =dataset.batch(32)dataset =dataset.repeat()
11,使用高阶 API
tf.train.MonitoredTrainingSession API 简化了在分布式设置下运行 TensorFlow 的很多方面。MonitoredTrainingSession 使用 tf.errors.OutOfRangeError 表示训练已完成,因此要将其与 tf.data API 结合使用,我们建议使用 Dataset.make_one_shot_iterator()
filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)dataset =dataset.map(...)dataset =dataset.shuffle(buffer_size=10000)dataset =dataset.batch(32)dataset =dataset.repeat(num_epochs)iterator =dataset.make_one_shot_iterator()next_example,next_label =iterator.get_next()loss =model_function(next_example,next_label)training_op =tf.train.AdagradOptimizer(...).minimize(loss)withtf.train.MonitoredTrainingSession(...)assess:whilenotsess.should_stop():sess.run(training_op)
要在 input_fn 中使用 Dataset(input_fn 属于 tf.estimator.Estimator),只需返回 Dataset 即可,框架将负责为您创建和初始化迭代器。例如:
defdataset_input_fn():filenames =["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]dataset =tf.data.TFRecordDataset(filenames)# Use `tf.parse_single_example()` to extract data from a `tf.Example`# protocol buffer, and perform any additional per-record preprocessing.defparser(record):keys_to_features ={"image_data":tf.FixedLenFeature((),tf.string,default_value=""),"date_time":tf.FixedLenFeature((),tf.int64,default_value=""),"label":tf.FixedLenFeature((),tf.int64,default_value=tf.zeros([],dtype=tf.int64)),}parsed =tf.parse_single_example(record,keys_to_features)# Perform additional preprocessing on the parsed data.image =tf.image.decode_jpeg(parsed["image_data"])image =tf.reshape(image,[299,299,1])label =tf.cast(parsed["label"],tf.int32)return{"image_data":image,"date_time":parsed["date_time"]},label
# Use `Dataset.map()` to build a pair of a feature dictionary and a label# tensor for each example.dataset =dataset.map(parser)dataset =dataset.shuffle(buffer_size=10000)dataset =dataset.batch(32)dataset =dataset.repeat(num_epochs)# Each element of `dataset` is tuple containing a dictionary of features# (in which each value is a batch of values for that feature), and a batch of# labels.returndataset