这是本人关于tf.data
的第二篇博文,第一篇基于TF-v1详细介绍了tf.data
,但是v1和v2很多地方不兼容,所以替大家瞧瞧v2的tf.data
模块有什么新奇之处。
TensorFlow版本:2.1.0
首先贴上TF v1版本的tf.data
博文地址:《TensorFlow tf.data 导入数据(tf.data官方教程)》
文章目录
使用 tf.data
构建数据输入通道
tf.data
API编写的数据输入通道简单、并且可重用度高。tf.data
能够实现非常复杂的数据输入通道。例如:图像模型的数据输入管道可能会聚集来自分布式文件系统中文件的数据,对每个图像应用随机扰动,然后将随机选择的图像合并为一批进行训练。文本模型的数据输入管道可能涉及从原始文本数据中提取符号,将其转换为带有查找表的嵌入标识符,以及将不同长度的序列分批处理。tf.data
API使得处理大量数据,从不同数据格式读取数据以及执行复杂的转换成为可能。
tf.data
API引入了tf.data.Dataset
这个抽象概念。它是一个元素组成的序列,每个元素可以由一个或多个部分组成。例如,图像的数据输入通道中,一个元素可以是由数据和标签组成的一个训练样本。
创建dataset的方法有两种:
- 基于内存中的数据 或 硬盘中的一个或多个文件 建立
Dataset
。 - 通过对
Dataset
进行 transform 得到一个新的Dataset
。
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
np.set_printoptions(precision=4)
1. 基础知识 ¶
建立一个数据输入通道,一般需要从数据源开始。如果你的数据储存在内存中,你可以使用tf.data.Dataset.from_tensor()
或tf.data.Dataset.from_tensor_slices()
创建Dataset
。如果你的数据是TFRecord格式,你可以使用tf.data.TFRecordDataset()
创建Dataset
。
一旦你有了一个Dataset对象,你可以通过调用它的方法对其进行变换产生一个新的 Dataset对象。
Dataset是一个Python可迭代对象。所以可以使用 for 循环来消耗它的元素:
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<TensorSliceDataset shapes: (), types: tf.int32>
for elem in dataset:
print(elem.numpy())
8
3
0
8
2
1
或者显式使用iter
创建一个Python迭代器,并使用next
来消耗其的元素:
it = iter(dataset)
print(next(it).numpy())
8
另外,也可以使用reduce()
变换来消耗数据集的元素,根据所有元素产生单个结果。下面的示例说明如何使用reduce变换来计算整数数据集的总和。
print(dataset.reduce(0, lambda state, value: state + value).numpy())
22
1.1 Dataset
结构介绍 ¶
一个Dataset
由多个相同结构的(嵌套)元素组成,每个元素又由多个可由tf.TypeSpec
表示的部分组成(常见的有Tensor, SparseTensor, RaggedTensor, TensorArray, Dataset)。
利用Dataset.element_spec
属性可以检查每个元素的组成部分的类型。该属性返回一个由tf.TypeSpec
对象组成的嵌套结构,这个结构与Dataset中元素的结构是对应的。例如:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random.uniform([4]),
tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
\; TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
\; (TensorSpec(shape=(), dtype=tf.float32, name=None),
\; \; TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))
dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor
Dataset 的变换支持任何结构的数据集。在使用 Dataset.map()
、Dataset.flat_map()
和 Dataset.filter()
函数时(这些转换会对每个元素应用一个函数),元素结构决定了函数的参数:
dataset1 = tf.data.Dataset.from_tensor_slices(
tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))
dataset1
<TensorSliceDataset shapes: (10,), types: tf.int32>
for z in dataset1:
print(z.numpy())
[6 7 1 1 5 6 7 8 7 6]
[8 3 3 7 9 3 8 4 8 4]
[2 3 6 9 4 2 1 8 1 6]
[6 7 1 9 6 2 4 7 9 1]
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random.uniform([4]),
tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2
<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3
<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>
for a, (b,c) in dataset3:
print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
注:为 Dataset 中的元素的各个组件命名通常会带来便利性(例如,元素的各个组件表示不同特征时)。除了元组之外,还可以使用 命名元组(collections.namedtuple
) 或 字典 来表示 Dataset 的单个元素。
dataset = tf.data.Dataset.from_tensor_slices(
{
"a": tf.random.uniform([4]),
"b": tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)})
dataset..element_spec
{‘a’: TensorSpec(shape=(), dtype=tf.float32, name=None), ‘b’: TensorSpec(shape=(100,), dtype=tf.int32, name=None)}
2. 读取输入数据 ¶
2.1 读取Numpy数组 ¶
See Loading NumPy arrays for more examples.
如果您的数据存储在内存中,则创建 Dataset
的最简单方法是使用Dataset.from_tensor_slices()
创建dataset。
train, test = tf.keras.datasets.fashion_mnist.load_data() # out is np array
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels)) # auto convert np array to constant tensor
dataset
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>
注意:上面的代码段会将 features 和 labels 数组作为 tf.constant()
嵌入 TensorFlow 图中。这非常适合小型数据集,但会浪费内存,因为这会多次复制数组的内容,并可能会达到 tf.GraphDef
协议缓冲区的 2GB 上限。
2.2 读取Python生成器中的数据 ¶
另一个常见的数据源是Python生成器。
注意:虽然使用Python生成器很简单,但这种方法的移植性、可扩展性较差。它必须与生成器运行在同一个Python进程中,并且它仍然受Python GIL的制约。
def count(stop):
i = 0
while i<stop:
yield i
i += 1
for n in count(5):
print(n)
0
1
2
3
4
Dataset.from_generator
可以将生成器转化为tf.data.Dataset
。.from_generator
函数将可调用对象作为输入,从而在到达生成器末尾时可重新启动生成器。它带有一个可选args参数,利用该参数可向可调用对象传递传递参数。
output_types参数是必需的,因为tf.data
会在后台构建一个tf.Graph
(图的边界需要tf.type
)。
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
print(count_batch.numpy())
[0 \, \, 1 \, \, 2 \, \, 3 \, \, 4 \, \, 5 \, \, 6 \, \, 7 \, \, 8 \, \, 9 \, ]
[10 \, 11 \, 12 \, 13 \, 14 \, 15 \, 16 \, 17 \, 18 \, 19]
[20 \, 21 \, 22 \, 23 \, 24 \, 0 \, 1 \, 2 \, 3 \, 4]
[ 5 \, 6 \, 7 \, 8 \, 9 \, 10 \, 11 \, 12 \, 13 \, 14]
[15 \, 16 \, 17 \, 18 \, 19 \, 20 \, 21 \, 22 \,