读取数据(Reading data)
TensorFlow输入数据的方式有四种:
- tf.data API :可以很容易的构建一个复杂的输入通道(pipeline)(首选数据输入方式)(Eager模式必须使用该API来构建输入通道)
- Feeding:使用Python代码提供数据,然后将数据feeding到计算图中。
- QueueRunner:基于队列的输入通道(在计算图计算前从队列中读取数据)
- Preloaded data:用一个constant常量将数据集加载到计算图中(主要用于小数据集)
**1. tf.data API **
关于tf.data.Dataset的更详尽解释请看《 [ programmer’s guide
](https://tensorflow.google.cn/programmers_guide/datasets) 》。tf.data
API能够从不同的输入或文件格式中读取、预处理数据,并且对数据应用一些变换(例如,batching、shuffling、mapping function
over the dataset),tf.data API 是旧的 feeding、QueueRunner的升级。
2. Feeding
注意:Feeding是数据输入效率最低的方式,应该只用于小数据集和调试(debugging)
TensorFlow的Feeding机制允许我们将数据输入计算图中的任何一个Tensor。因此可以用Python来处理数据,然后直接将处理好的数据feed到计算图中
。
在 run()
或 eval()
中用 feed_dict
来将数据输入计算图:
with tf.Session():
input = tf.placeholder(tf.float32)
classifier = ...
print(classifier.eval(feed_dict={
input: my_python_preprocessing_fn()}))
虽然你可以用feed data替换任何Tensor的值(包括variables和constants),但最好的使用方法是使用一个 tf.placeholder
节点(专门用于feed数据)。它不用初始化,也不包含数据。一个placeholder没有被feed数据,则会报错。
使用placeholder和feed_dict的一个实例(数据集使用的是MNIST)见 tensorflow/examples/tutorials/mnist/fully_connected_feed.py
3. QueueRunner
注意:这一部分介绍了基于队列(Queue)API构建输入通道(pipelines),这一方法完全可以使用 tf.data API来替代。
一个基于queue的从文件中读取records的通道(pipline)一般有以下几个步骤:
- 文件名列表(The list of filenames)
- 文件名打乱(可选)(Optional filename shuffling)
- epoch限制(可选)(Optional epoch limit)
- 文件名队列(Filename queue)
- 与文件格式匹配的Reader(A Reader for the file format)
- decoder(A decoder for a record read by the reader)
- 预处理(可选)(Optional preprocessing)
- Example队列(Example queue)
3.1 Filenames, shuffling, and epoch limits
对于文件名列表,有很多方法:1. 使用一个constant string Tensor(比如: ["file0", "file1"]
)或者 [("file%d" %i) for i in range(2)]
;2. 使用 tf.train.match_filenames_once
函数;3. 使用 tf.gfile.Glob(path_pattern)
。
将文件名列表传给 tf.train.string_input_producer
函数。 string_input_producer
创建一个
FIFO 队列来保存(holding)文件名,以供Reader使用。
string_input_producer
可以对文件名进行shuffle(可选)、设置一个最大迭代 epochs
数。在每个epoch,一个queue runner将整个文件名列表添加到queue,如果 shuffle=True
,则添加时进行shuffle。This procedure provides a uniform sampling of files, so that
examples are not under- or over- sampled relative to each other。
queue runner线程独立于reader线程,所以enqueuing和shuffle不会阻碍reader。
3.2 File formats
要选择与输入文件的格式匹配的reader,并且要将文件名队列传递给reader的 read 方法。read 方法输出一个 key identifying
the file and record(在调试过程中非常有用,如果你有一些奇怪的 record)
**3.2.1 CSV file
**
为了读取 逗号分隔符分割的text文件(csv), 要使用一个 tf.TextLineReader
和一个 tf.decode_csv
。例如:
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults