注意,在TensorFlow 1.3中,Dataset API是放在contrib包中的:
tf.contrib.data.Dataset
而在TensorFlow 1.4中,Dataset API已经从contrib包中移除,变成了核心API的一员:
tf.data.Dataset
此前,在TensorFlow中读取数据一般有两种方法:
使用placeholder读内存中的数据
使用queue读硬盘中的数据
Dataset API同时支持从内存和硬盘的读取,相比之前的两种方法在语法上更加简洁易懂。此外,如果想要用到TensorFlow新出的Eager模式,就必须要使用Dataset API来读取数据。
三、基本使用
1、一维数据集示范基本使用
Google官方给出的Dataset API中的类图:
在初学时,我们只需要关注两个最重要的基础类:Dataset和Iterator。
Dataset可以看作是相同类型“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。
数据集对象实例化:
dataset = tf.data.Dataset.from_tensor_slices(数据)
迭代器对象实例化(非Eager模式下):
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
综合起来效果如下,
1 2 3 4 5 6 7 8 9 |
|
输出:1.0 2.0 3.0 4.0 5.0
读取结束异常:
如果一个dataset中元素被读取完了,再尝试sess.run(one_element)的话,就会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据的行为是一致的。
在实际程序中,可以在外界捕捉这个异常以判断数据是否读取完,综合以上三点请参考下面的代码:
1 2 3 4 5 6 7 8 9 10 |
|
输出:1.0 2.0 3.0 4.0 5.0 end!
2、高维数据集使用
tf.data.Dataset.from_tensor_slices真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作都以第一维为基础。
dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))
传入的数值是一个矩阵,它的形状为(5, 2),tf.data.Dataset.from_tensor_slices就会切分它形状上的第一个维度,最后生成的dataset中一个含有5个元素,每个元素的形状是(2, ),即每个元素是矩阵的一行。
1 2 3 4 5 6 7 8 9 10 |
|
[0.09787406 0.71672957] [0.25681324 0.81974072] [0.35186046 0.39362398] [0.75228199 0.6534702 ] [0.39695169 0.9341708 ] end!
3、字典使用
在实际使用中,我们可能还希望Dataset中的每个元素具有更复杂的形式,如每个元素是一个Python中的元组,或是Python中的词典。例如,在图像识别问题中,一个元素可以是{“image”: image_tensor, “label”: label_tensor}的形式,这样处理起来更方便,
注意,image_tensor、label_tensor和上面的高维向量一致,第一维表示数据集中数据的数量。相较之下,字典中每一个key值可以看做数据的一个属性,value则存储了所有数据的该属性值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
|
{'a': 1.0, 'b': array([0.31721037, 0.33378767])} {'a': 2.0, 'b': array([0.99221946, 0.65894961])} {'a': 3.0, 'b': array([0.98405468, 0.11478854])} {'a': 4.0, 'b': array([0.95311317, 0.57432678])} {'a': 5.0, 'b': array([0.46067428, 0.19716722])} end!
4、复杂的tuple组合数据
类似的