TensorFlow版本:1.12.0
本篇主要介绍怎么使用 tf.data
API 来构建高性能的输入 pipeline。
tf.data
官方教程详见前面的博客<<<<<<<<<<tf.data
官方教程
GPU、TPU的使用能够从根本上减少单个训练step所需的时间。但优异的性能不仅依赖于高速的计算硬件,也要求有一个高效的输入管道(Input Pipeline Performance Guide),这个管道在当前step完成前,进行下一个 step 需要的数据的准备。
tf.data
API 对于灵活且高效的输入管道的建立非常有帮助。这个文档解释了
tf.data
API 的特性,并介绍了构建高性能的 TensorFlow 数据输入管道的过程。
本文主要包含以下内容:
- 介绍数据输入管道的结构(本质是一个 ETL 过程)。
- 在
tf.data
中,优化数据输入管道的常用方法。 - 介绍了数据操作顺序对数据输入管道性能的影响。
- 优异的数据输入管道应该具备的一些特质。
1. 数据输入管道的结构
TensorFlow数据输入管道可以被抽象为一个 ETL 过程(Extract,Transform,Load):
- Extract:从硬盘上读取数据 ------ 可以是本地(HDD 或 SSD),也可以是网盘(GCS 或 HDFS)
- Transform:使用 CPU 去解析、预处理数据 ------ 比如:图像解码、数据增强、变换(比如:随机裁剪、翻转、颜色变换)、打乱、batching。
- Load:将 Transform 后的数据加载到 计算设备 ------ 例如:GPU、TPU 等设备。
上述的数据输入管道使用 CPU 来进行数据的 ETL 过程,从而让 GPU、TPU 等设备专心进行模型的训练过程(提高了设备的利用率)。另外,将数据输入管道抽象为 ETL 过程,有利于我们对数据输入管道进行优化。
当使用 tf.estimator.Estimator
API 时,input_fn
需要完成 Extract 和 Transform 两个阶段。
def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"img_encoded": tf.FixedLenFeature((), tf.string, ""),
"img_label": tf.FixedLenFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["img_encoded"])
return image, parsed["img_label"