作者:叶 虎
编辑:赵一帆
前 言
在训练模型时,我们首先要处理的就是训练数据的加载与预处理的问题,这里称这个过程为输入流水线(input pipelines,或输入管道,[参考:https://www.tensorflow.org/performance/datasets_performance])。在TensorFlow中,典型的输入流水线包含三个流程(ETL流程):
-
提取(Extract):从存储介质(如硬盘)中读取数据,可能是本地读取,也可能是远程读取(比如在分布式存储系统HDFS)
-
预处理(Transform):利用CPU处理器解析和预处理提取的数据,如图像解压缩,数据扩增或者变换,然后会做random shuffle,并形成batch。
-
加载(load):将预处理后的数据加载到加速设备中(如GPUs)来执行模型的训练。
输入流水线对于加速模型训练还是很重要的,如果你的CPU处理数据能力跟不上GPU的处理速度,此时CPU预处理数据就成为了训练模型的瓶颈环节。除此之外,上述输入流水线本身也有很多优化的地方。比如,一个典型的模型训练过程中,CPU预处理数据时,GPU是闲置的,当GPU训练模型时,CPU是闲置的,这个过程如下所示:
这样一个训练step中所花费的时间是CPU预处理数据和GPU训练模型时间的总和。显然这个过程中有资源浪费,一个改进的方法就是交叉CPU数据处理和GPU模型训练这两个过程,当GPU处于第个训练阶段,CPU正在准备第N+1步所需的数据,如下图所示:
明显上述设计可以充分最大化利用CPU和GPU,从而减少资源的闲置。另外当存在多个CPU核心时,这又会涉及到CPU的并行化技术(多线程)来加速数据预处理过程,因为每个训练样本的预处理过程往往是互相独立的。关于输入流程线的优化可以参考TensorFlow官网上的Pipeline Performance Guide(https://www.tensorflow.org/performance/datasets_performance),相信你会受益匪浅。
幸运的是,最新的TensorFlow版本提供了tf.data这一套APIs来帮助我们快速实现高效又灵活的输入流水线。在TensorFlow中最常见的加载训练数据的方式是通过Feeding(https://www.tensorflow.org/api_guides/python/reading_data#Feeding)方式,其主要是定义placeholder,然后将通过Session.run()的feed_dict参数送入数据,但是这其实是最低效的加载数据方式。后来,TensorFlow增加了QueueRunner(https://www.tensorflow.org/api_guides/python/reading_data#_QueueRunner_)机制,其主要是基于文件队列以及多线程技术,实现了更高效的输入流水线,但是其APIs很是让人难懂,所以就有了现在的tf.data来替代它。
这里我们通过mnist实例来讲解如何使用tf.data建立简洁而高效的输入流水线,在介绍之前,我们先介绍如何制作TFRecords文件,这是TensorFlow支持的一种标准文件格式
1
制作TFRecords文件
TFRecords文件是TensorFlow中的标准数据格式,它是基于protobuf的二进制文件,每个TFRecord文件的基本元素是tf.train.Example,其对应的是数据集中的一个样本数据,每个Example包含Features,存储该样本的各个feature,每个feature包含一个键值对,分别对应feature的特征名与实际值。下面是一个Example实例:
// An Example for a movie recommendation application:
features {
feature {
key: "age"
value { float_list {
value: 29.0
}}
}
feature {
key: "movie"
value { bytes_list {
value: "The Shawshank Redemption"
value: "Fight Club"
}}
}
feature {
key: "movie_ratings"
value { float_list {
value: 9.0
value: 9.7
}}
}
feature {
key: "suggestion"
value { bytes_list {
value: "Inception"
}}
}
feature {
key: "suggestion_purchased"
value { float_list {
value: 1.0
}}
}
feature {
key: "purchase_price"
value { float_list {
value: 9.99
}}
}
}
上面是一个电影推荐系统中的一个样本,可以看到它共含有6个特征,每个特征都是key-value类型,key是特征名,而value是特征值,值得注意的是value其实存储的是一个list,根据数据类型共分为三种:bytes_list, float_list和int64_list,分别存储字节、浮点及整数类型(见这里:https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/core/example/feature.proto)。
作为标准数据格式,TensorFlow当然提供了创建TFRecords文件的python接口,下面我们创建mnist数据集对应的TFRecords文件。对于mnist数据集,每个Example需要存储两个feature,