实例介绍TensorFlow的输入流水线

本文介绍了在TensorFlow中构建输入流水线的重要性,包括提取、预处理和加载数据的ETL流程。作者强调了优化输入流水线以充分利用CPU和GPU资源的必要性,并对比了不同数据加载方式的效率。文章通过mnist数据集实例,详细讲解了如何使用tf.data API创建输入流水线,包括制作和读取TFRecords文件,以及如何创建feedable iterator进行训练和验证。
摘要由CSDN通过智能技术生成

作者:叶   虎

编辑:赵一帆

前  言

在训练模型时,我们首先要处理的就是训练数据的加载与预处理的问题,这里称这个过程为输入流水线(input pipelines,或输入管道,[参考:https://www.tensorflow.org/performance/datasets_performance])。在TensorFlow中,典型的输入流水线包含三个流程(ETL流程):

  1. 提取(Extract):从存储介质(如硬盘)中读取数据,可能是本地读取,也可能是远程读取(比如在分布式存储系统HDFS)

  2. 预处理(Transform):利用CPU处理器解析和预处理提取的数据,如图像解压缩,数据扩增或者变换,然后会做random shuffle,并形成batch。

  3. 加载(load):将预处理后的数据加载到加速设备中(如GPUs)来执行模型的训练。

输入流水线对于加速模型训练还是很重要的,如果你的CPU处理数据能力跟不上GPU的处理速度,此时CPU预处理数据就成为了训练模型的瓶颈环节。除此之外,上述输入流水线本身也有很多优化的地方。比如,一个典型的模型训练过程中,CPU预处理数据时,GPU是闲置的,当GPU训练模型时,CPU是闲置的,这个过程如下所示:

图片

这样一个训练step中所花费的时间是CPU预处理数据和GPU训练模型时间的总和。显然这个过程中有资源浪费,一个改进的方法就是交叉CPU数据处理和GPU模型训练这两个过程,当GPU处于第个训练阶段,CPU正在准备第N+1步所需的数据,如下图所示:

图片

明显上述设计可以充分最大化利用CPU和GPU,从而减少资源的闲置。另外当存在多个CPU核心时,这又会涉及到CPU的并行化技术(多线程)来加速数据预处理过程,因为每个训练样本的预处理过程往往是互相独立的。关于输入流程线的优化可以参考TensorFlow官网上的Pipeline Performance Guide(https://www.tensorflow.org/performance/datasets_performance),相信你会受益匪浅。

幸运的是,最新的TensorFlow版本提供了tf.data这一套APIs来帮助我们快速实现高效又灵活的输入流水线。在TensorFlow中最常见的加载训练数据的方式是通过Feeding(https://www.tensorflow.org/api_guides/python/reading_data#Feeding)方式,其主要是定义placeholder,然后将通过Session.run()的feed_dict参数送入数据,但是这其实是最低效的加载数据方式。后来,TensorFlow增加了QueueRunner(https://www.tensorflow.org/api_guides/python/reading_data#_QueueRunner_)机制,其主要是基于文件队列以及多线程技术,实现了更高效的输入流水线,但是其APIs很是让人难懂,所以就有了现在的tf.data来替代它。

这里我们通过mnist实例来讲解如何使用tf.data建立简洁而高效的输入流水线,在介绍之前,我们先介绍如何制作TFRecords文件,这是TensorFlow支持的一种标准文件格式

1

制作TFRecords文件

TFRecords文件是TensorFlow中的标准数据格式,它是基于protobuf的二进制文件,每个TFRecord文件的基本元素是tf.train.Example,其对应的是数据集中的一个样本数据,每个Example包含Features,存储该样本的各个feature,每个feature包含一个键值对,分别对应feature的特征名与实际值。下面是一个Example实例:

// An Example for a movie recommendation application:
       features {
         feature {
           key: "age"
           value { float_list {
             value: 29.0
           }}
         }
         feature {
           key: "movie"
           value { bytes_list {
             value: "The Shawshank Redemption"
             value: "Fight Club"
           }}
         }
         feature {
           key: "movie_ratings"
           value { float_list {
             value: 9.0
             value: 9.7
           }}
         }
         feature {
           key: "suggestion"
           value { bytes_list {
             value: "Inception"
           }}
         }
         feature {
           key: "suggestion_purchased"
           value { float_list {
             value: 1.0
           }}
        }
         feature {
           key: "purchase_price"
           value { float_list {
             value: 9.99
           }}
         }
      }

上面是一个电影推荐系统中的一个样本,可以看到它共含有6个特征,每个特征都是key-value类型,key是特征名,而value是特征值,值得注意的是value其实存储的是一个list,根据数据类型共分为三种:bytes_list, float_listint64_list,分别存储字节、浮点及整数类型(见这里:https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/core/example/feature.proto)。

作为标准数据格式,TensorFlow当然提供了创建TFRecords文件的python接口,下面我们创建mnist数据集对应的TFRecords文件。对于mnist数据集,每个Example需要存储两个feature,

根据提供的引用内容,您所提到的"DNN拟合代码实例Tensorflow"是指使用Tensorflow库来实现深度神经网络(DNN)进行拟合的代码示例。 在引用中提到了使用了TensorFlow.NET和SciSharp.Models.TimeSeries这两个NuGet包。TensorFlow.NET是一个TensorFlow的.NET绑定库,而SciSharp.Models.TimeSeries是一个用于时间序列数据建模的科学计算库。 根据引用的描述,这个代码示例可能是作者对字符识别问题进行研究的学习笔记,并且代码主要使用VB.NET编写,因为作者在工程开发中习惯使用VB.NET。需要注意的是,VB.NET和C#在语法和Tensorflow.NET库的使用上可能存在一些差异。 引用中提到,作者尝试使用多层感知机(DNN的一种)来改进模型的拟合效果,并表示在之前的人工神经网络代码中测试过类似的东西,并且取得了较好的准确率。 最后,在引用中提到了一个问题,即在Tensorflow中未设置随机种子,但可以通过设置tf.set_random_seed()来固定随机种子。 综上所述,如果您想寻找关于使用Tensorflow库实现DNN拟合的代码实例,您可以参考TensorFlow.NET和SciSharp.Models.TimeSeries这两个包的文档和示例代码,尤其是作者所提到的SciSharp-Stack-Examples-master项目中的范例。同时,需要注意VB.NET和C#语法的差异以及Tensorflow.NET库的使用方式。<span class="em">1</span><span class="em">2</span><span class="em">3</span><span class="em">4</span>
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值