tensorflow自建数据集dataset使用
1 Dataset
tensorflow中提供了两个dataset的API,一个是做一个数据源,另一个是做一个管道用来不断提取数据。
tf.data.Dataset:表示一串元素(elements),其中每个元素包含了一或多个Tensor对象。例如:在一个图片pipeline中,一个元素可以是单个训练样本,它们带有一个表示图片数据的tensors和一个label组成的pair。有两种不同的方式创建一个dataset:
创建一个source (例如:Dataset.from_tensor_slices()), 从一或多个tf.Tensor对象中构建一个dataset
应用一个transformation(例如:Dataset.batch()),从一或多个tf.data.Dataset对象上构建一个dataset
tf.data.Iterator:它提供了主要的方式来从一个dataset中抽取元素。通过Iterator.get_next() 返回的该操作会yields出Datasets中的下一个元素,作为输入pipeline和模型间的接口使用。最简单的iterator是一个“one-shot iterator”,它与一个指定的Dataset相关联,通过它来进行迭代。对于更复杂的使用,Iterator.initializer操作可以使用不同的datasets重新初始化(reinitialize)和参数化(parameterize)一个iterator ,例如,在同一个程序中通过training data和validation data迭代多次。
2、tf.data.Dataset
一般我们可以从tensor序列直接导入到Dataset中,如下几个例子,直接是tensor
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2