10.6全新的数据读取方式——Dataset API
Dataset API 是Tensorflow中一种全新的数据读取方式,它可以用简单复用的方式构建复杂的Input Pipeline。(Pipeline意为流水线)
Dataset API主要用于数据读取,构建输入数据的pipeline等,Datset API可以很方便地以不同的数据格式处理大量数据及复杂的转换。
10.6.1 Dataset API架构
DataSet API中最重要的两个基础类:Dataset和Iterator。Dataset可以看作是相同类型的“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。
如何将这个dataset中的元素取出呢?方法是从Dataset中实例化一个Iterator(迭代器,作用是遍历所有的元素),然后对Iterator进行迭代。
下面我们对Dataset API中的两个基础类做进一步说明。
1.tf.data.Dataset
tf.data.Dataset表示一串元素(elements),其中每个元素包含了一个或者多个Tensor对象(张量,多维数组)。例如:在一个图片Pipeline中,一个元素可以是单个训练样本,该样本带有一个表示图像数据的tensors和一个label组成的数据对。有两种不同的方式创建一个dataset:
1)创建一个source,从一个或者多个tf.Tensor对象中构建一个dataset。
2)拥有Dataset对象以后,应用一个transformation,可以将它们转化为新的Dataset。
最常使用Dataset的方式是使用一个迭代器
2.tf.data.Iterator
它提供的主要方式是从一个dataset中抽取元素。
10.6.2 构建Dataset
构建Dataset方法很多,常用的几种方法如下。
1.tf.data.Dataset.from_tensor_slices()
利用tf.data.Dataset.from_tensor_slices()从一个或多个tf.Tensor对象中构建一个dataset,其tf.Tensor对象中包括数组、矩阵、字典、元组等,具体实例如下:
import tensorflow as tf
import numpy as np
arry1=np.array([1.0,2.0,3.0,4.0,5.0])
dataset=tf.data.Dataset.from_tensor_slices(arry1)#将array数组转化为tensor对象的数据集
iterator=dataset.make_one_shot_iterator()#生成实例,即创建迭代器
#从iterator里取出一个元素
one_element=iterator.get_next()#get_next()函数用于返回下一个tensor对象
with tf.Session() as sess:
for i in range(len(arry1)):
print(sess.run(one_element))
2.Dataset的转换(transformations)
Datasets支持任何结构,当使用Dataset.map(),Dataset.flat_map(),以及Dataseet.filter()进行转换时,它们会对每个元素应用一个函数,元素结构决定了函数的参数。
Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换或预处理。
以下是一些简单示例。
(1)使用map()
import tensorflow as tf
import numpy as np
arry1=np.array([1.0,2.0,3.0,4.0,5.0])
dataset=tf.data.Dataset.from_tensor_slices(arry1)#将array数组转化为tensor对象的数据集
iterator=dataset.make_one_shot_iterator()#生成实例,即创建迭代器
#从iterator里取出一个元素
one_element=iterator.get_next()#get_next()函数用于返回下一个tensor对象
with tf.Session() as sess:
for i in range(len(arry1)):
print(sess.run(one_element))
(2)使用flat_map()、filter()等
#使用‘Dataset.flat_map()’将每个文件转换为一个单独的嵌套数据集
#然后将它们的内容顺序连接成一个单一的“扁平”数据集
#跳过第一行(标题行)
#过滤以“#”开头的行
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.flat_map(
lambda filename: (
tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))