参考 基于tensorflow的图像处理(四) 数据集处理 - 云+社区 - 腾讯云
除队列以外,tensorflow还提供了一套更高的数据处理框架。在新的框架中,每一个数据来源被抽象成一个“数据集”,开发者可以以数据集为基本对象,方便地进行batching、随机打乱(shuffle)等操作。
一、数据集的基本使用方法
在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个TFRecord文件,一个文本文件,或者经过sharding的一系列文件,等等。由于训练数据集通常无法全部写入内存中,从数据中读取数据时需要使用一个迭代器(iterator)按顺序进行读取,这点与队列的dequeue()操作和Reader的read()操作相似。与队列相似,数据集也是计算图上的一个点。
下面先看一个简单的例子,这个例子从一个张量创建一个数据集,遍历这个数据集,并对每个输入输出y=x^2的值。
import tensorflow as tf
# 从一个数组创建数据集。
input_data = [1, 2, 3, 4, 5, 6]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# 定义一个迭代器用于遍历数据集。因为上面定义的数据集没有用placeholder
# 作为输入参数,所以这里可以使用最简单的one_shot_iterator。
iterator = dataset.make_one_shot_iterator()
# get_next() 返回代表一个输入数据的张量,类似于队列中dequeue()。
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
for i in range(len(input_data)):
print(sess.run(y))
输出:
---
1
4
9
16
25
36
---
从以上例子可以看到,利用数据集读取数据有三个基本步骤。
1.定义数据集的构造方法
这个例子使用了tf.data.Dataset.from_tensor_slice(),表明数据集是从一个张量中构建的。如果数据集是从文件中构建的,则需要相应调用不同的构造方法。
2.定义遍历器
这个例子使用了最简单的one_shot_iterator来遍历数据集。
3.使用get_next()方法从遍历器中读取数据张量,作为计算图其他部分的输入
在图像相关任务中,输入数据通常以TFRecord形式存储,这时可以用TFRecordDataset来读取数据。与文本文件不同, 每一个TFRecord都有自己不同的feature格式,因此在读取TFRecord时,需要提供一个parser函数来解析所读取的TFRecord的数据格式。
import tensorflow as tf
# 解析一个TFRecord的方法。record是从文件中读取的一个样例。
def parser(record):
# 解析读入的一个样例
features = tf.parse_single_example(
record,
features={
'feat1': tf.FixedLenFeature([], tf.int64),
'feat2': tf.FixedLenFeature([], tf.int64),
})
return features['feat1'], features['feat2']
# 从TFRecord文件创建数据集
input_files = ["/path/to/input_file1", "/path/to/input_fi