TensorFlow 数据读取方法总结

读取数据(Reading data)

下一篇:tf.data 官方教程

推荐:如何构建高性能的输入 pipeline

TensorFlow输入数据的方法有四种:

  • tf.data API:可以很容易的构建一个复杂的输入通道(pipeline)(首选数据输入方式)(Eager模式必须使用该API来构建输入通道)
  • Feeding:使用Python代码提供数据,然后将数据feeding到计算图中。
  • QueueRunner:基于队列的输入通道(在计算图计算前从队列中读取数据)
  • Preloaded data:用一个constant常量将数据集加载到计算图中(主要用于小数据集)

1. tf.data API

关于tf.data.Dataset的更详尽解释请看《programmer’s guide》。tf.data API能够从不同的输入或文件格式中读取、预处理数据,并且对数据应用一些变换(例如,batching、shuffling、mapping function over the dataset),tf.data API 是旧的 feeding、QueueRunner的升级。

tf.data的教程见下篇: TensorFlow导入数据(tf.data)

tf.data 到底为何物:

  • tf.data 之前,一般使用 QueueRunner,但 QueueRunner 基于 Python 的多线程及队列等,效率不够高,所以 Google发布了tf.data,其基于C++的多线程及队列,彻底提高了效率。所以不建议使用 QueueRunner 了,取而代之,使用 tf.data 模块吧:简单、高效。

2. Feeding

注意:Feeding是数据输入效率最低的方式,应该只用于小数据集和调试(debugging)

TensorFlow的Feeding机制允许我们将数据输入计算图中的任何一个Tensor。因此可以用Python来处理数据,然后直接将处理好的数据feed到计算图中 。

run()eval()中用feed_dict来将数据输入计算图:

with tf.Session():
  input = tf.placeholder(tf.float32)
  classifier = ...
  print(classifier.eval(feed_dict={
   input: my_python_preprocessing_fn()}))

虽然你可以用feed data替换任何Tensor的值(包括variables和constants),但最好的使用方法是使用一个tf.placeholder节点(专门用于feed数据)。它不用初始化,也不包含数据。一个placeholder没有被feed数据,则会报错。

使用placeholder和feed_dict的一个实例(数据集使用的是MNIST)见tensorflow/examples/tutorials/mnist/fully_connected_feed.py

3. QueueRunner

注意:这一部分介绍了基于队列(Queue)API构建输入通道(pipelines),这一方法完全可以使用 tf.data API来替代。

一个基于queue的从文件中读取records的通道(pipline)一般有以下几个步骤:

  1. 文件名列表(The list of filenames)
  2. 文件名打乱(可选)(Optional filename shuffling)
  3. epoch限制(可选)(Optional epoch limit)
  4. 文件名队列(Filename queue)
  5. 与文件格式匹配的Reader(A Reader for the file format)
  6. decoder(A decoder for a record read by the reader)
  7. 预处理(可选)(Optional preprocessing)
  8. Example队列(Example queue)

3.1 Filenames, shuffling, and epoch limits

对于文件名列表,有很多方法:1. 使用一个constant string Tensor(比如:["file0", "file1"])或者[("file%d" %i) for i in range(2)];2. 使用 tf.train.match_filenames_once 函数;3. 使用 tf.gfile.Glob(path_pattern)

将文件名列表传给 tf.train.string_input_producer 函数。string_input_producer 创建一个 FIFO 队列来保存(holding)文件名,以供Reader使用。

string_input_producer 可以对文件名进行shuffle(可选)、设置一个最大迭代 epochs 数。在每个epoch,一个queue runner将整个文件名列表添加到queue,如果shuffle=True,则添加时进行shuffle。This procedure provides a uniform sampling of files, so that examples are not under- or over- sampled relative to each other。

queue runner线程独立于reader线程,所以enqueuing和shuffle不会阻碍reader。

3.2 File formats

要选择与输入文件的格式匹配的reader,并且要将文件名队列传递给reader的 read 方法。read 方法输出一个 key identifying the file and record(在调试过程中非常有用,如果你有一些奇怪的 record)

3.2.1 CSV file

为了读取逗号分隔符分割的text文件(csv),要使用一个 tf.TextLineReader 和一个 tf.decode_csv。例如:

filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [
  • 27
    点赞
  • 107
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值