tf.data API 引入了一个 tf.data.Dataset 抽象,它表示一个元素序列,其中每个元素都由一个或多个组件组成。
tf.data可以读取多种文件,然后生成数据:文本文件、tfrecords等
dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
dataset = tf.data.Dataset.list_files("/path/*.txt")
使用tf.data.Dataset.from_tensor_slices((x,y))生成数据集,传入一个元组;
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
for item in dataset:
print(item.numpy())
读入数据:
(1)使用numpy数组:
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
注:上面的代码段会将 features
和 labels
数组作为 tf.constant() 运算嵌入到 TensorFlow 计算图中。这对于小数据集来说效果很好,但是会浪费内存(因为数组的内容会被多次复制),并且可能会达到 tf.GraphDef
协议缓冲区的 2GB 上限。
(2)使用python生成器:
def count(stop):
i = 0
while i<stop:
yield i
i += 1
"""
for n in count(5):
print(n)
"""
#定义生成器:
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
print(count_batch.numpy())
Dataset.from_generator 构造函数会将 Python 生成器转换为具有完整功能的 tf.data.Dataset。
构造函数会获取可调用对象作为输入,而非迭代器。这样,构造函数结束后便可重启生成器。构造函数会获取一个可选的 args
参数,作为可调用对象的参数。
output_types
参数是必需的,因为 tf.data 会在内部构建 tf.Graph,而计算图边缘需要 tf.dtype
。output_shapes
参数虽然不是必需的,但强烈建议添加,因为许多 TensorFlow 运算不支持秩未知的张量。如果特定轴的长度未知或可变,请在 output_shapes
中将其设置为 None
。
(3):使用TFRecord:
tf.data API 支持多种文件格式,因此可以处理不适合存储在内存中的大型数据集。例如,TFRecord 文件格式是一种简单的、面向记录的二进制格式,许多 TensorFlow 应用都将其用于训练数据。您可以利用 tf.data.TFRecordDataset 类将一个或多个 TFRecord 文件的内容作为输入流水线的一部分进行流式传输。