Tensorflow2.0入门教程10:tf.data.Dataset使用介绍

很多时候,我们希望使用自己的数据集来训练模型。然而,面对一堆格式不一的原始数据文件,将其预处理并读入程序的过程往往十分繁琐,甚至比模型的设计还要耗费精力。比如,为了读入一批图像文件,我们可能需要纠结于 python 的各种图像处理包(比如 pillow ),自己设计 Batch 的生成方式,最后还可能在运行的效率上不尽如人意。为此,TensorFlow 提供了 tf.data 这一模块,包括了一套灵活的数据集构建 API,能够帮助我们快速、高效地构建数据输入的流水线,尤其适用于数据量巨大的场景。

tf.data数据集对象的建立

tf.data 的核心是 tf.data.Dataset 类,提供了对数据集的高层封装。tf.data.Dataset 由一系列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为 长×宽×通道数 的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)。

最基础的建立 tf.data.Dataset 的方法是使用 tf.data.Dataset.from_tensor_slices() ,适用于数据量较小(能够整个装进内存)的情况。具体而言,如果我们的数据集中的所有元素通过张量的第 0 维,拼接成一个大的张量(例如,前面的 MNIST 数据集的训练集即为一个 [60000, 28, 28] 的张量,表示了 60000 张 28*28 的单通道灰度图像),那么我们提供一个这样的张量或者第 0 维大小相同的多个张量作为输入,即可按张量的第 0 维展开来构建数据集,数据集的元素数量为张量第 0 位的大小。具体示例如下:

import numpy as np
import tensorflow as tf

一、创建Dataset

tf.data.Dataset 是一个 Python 的可迭代对象,因此可以使用 For 循环迭代获取数据

也可以使用 iter() 显式创建一个 Python 迭代器并使用 next() 获取下一个元素,即:

X = tf.constant([1,2,3,4,5,6])
Y = tf.constant([10,20,30,40,50,60])
# 也可以使用NumPy数组,效果相同
# X = np.array([1,2,3,4,5,6])
# Y = np.array([10,20,30,40,50,60])
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
for d in dataset:
    print(d)
    #(input,label)
(<tf.Tensor: id=9, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=10, shape=(), dtype=int32, numpy=10>)
(<tf.Tensor: id=11, shape=(), dtype=int32, numpy=2>, <tf.Tensor: id=12, shape=(), dtype=int32, numpy=20>)
(<tf.Tensor: id=13, shape=(), dtype=int32, numpy=3>, <tf.Tensor: id=14, shape=(), dtype=int32, numpy=30>)
(<tf.Tensor: id=15, shape=(), dtype=int32, numpy=4>, <tf.Tensor: id=16, shape=(), dtype=int32, numpy=40>)
(<tf.Tensor: id=17, shape=(), dtype=int32, numpy=5>, <tf.Tensor: id=18, shape=(), dtype=int32, numpy=50>)
(<tf.Tensor: id=19, shape=(), dtype=int32, numpy=6>, <tf.Tensor: id=20, shape=(), dtype=int32, numpy=60>)
it = iter(dataset)
next(it)
(<tf.Tensor: id=54, shape=(), dtype=int32, numpy=1>,
 <tf.Tensor: id=55, shape=(), dtype=int32, numpy=10>)
a = np.random.uniform(size=(5,2))
dataset = tf.data.Dataset.from_tensor_slices(a)
for d in dataset:
    print(d)
tf.Tensor([0.94332744 0.9150543 ], shape=(2,), dtype=float64)
tf.Tensor([0.8328214  0.30535439], shape=(2,), dtype=float64)
tf.Tensor([0.01009141 0.04326913], shape=(2,), dtype=float64)
tf.Tensor([0.23589543 0.34819133], shape=(2,), dtype=float64)
tf.Tensor([0.94740998 0.27282243], shape=(2,), dtype=float64)

载入Mnist数据集

import matplotlib.pyplot as plt 
%matplotlib inline
(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
for image, label in mnist_dataset:
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :])
    plt.show()

在这里插入图片描述
在这里插入图片描述
<matplotlib.figure.Figure at 0x1e1e1ddacc0>

二、数据集对象的预处理

tf.data.Dataset 类为我们提供了多种数据集预处理方法。最常用的如:

  • Dataset.map(f) :对数据集中的每个元素应用函数 f ,得到一个新的数据集(这部分往往结合 tf.io 进行读写和解码文件, tf.image 进行图像处理);

  • Dataset.shuffle(buffer_size) :将数据集打乱(设定一个固定大小的缓冲区(Buffer),取出前 buffer_size 个元素放入,并从缓冲区中随机采样,采样后的数据用后续数据替换);

  • Dataset.batch(batch_size) :将数据集分成批次;

  • Dataset.repeat():重复数据集的元素,epoch

2.1 map

Dataset.map()将数组中元素加1

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4])
dataset = dataset.map(lambda x:x+1)
for d in dataset:
    print(d)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)

使用 Dataset.map() 将所有图片旋转 90 度

mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
def rot90(image, label):
    image = tf.image.rot90(image)
    return image, label
(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1) 
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
mnist_dataset = mnist_dataset.map(rot90)
for image, label in mnist_dataset:
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :, 0])
    plt.show()

在这里插入图片描述
在这里插入图片描述

2.2 batch

每次产生2个数据

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6])
dataset = dataset.batch(4)
for d in dataset:
    print(d)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([5 6], shape=(2,), dtype=int32)

使用 Dataset.batch() 将数据集划分批次,每个批次的大小为 4:

(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1) 
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
mnist_dataset = mnist_dataset.batch(4)
for images, labels in mnist_dataset:    # image: [4, 28, 28, 1], labels: [4]
    fig, axs = plt.subplots(1, 4)
    for i in range(4):
        axs[i].set_title(labels.numpy()[i])
        axs[i].imshow(images.numpy()[i, :, :, 0])
    plt.show()

在这里插入图片描述
在这里插入图片描述
<matplotlib.figure.Figure at 0x1e1e2ed9940>

2.3 shuffle

使用 Dataset.shuffle():将数据打散后再设置批次,shuffle的功能为打乱dataset中的元素, 它会维持一个固定大小的buffer,并从该buffer中随机均匀地选择下一个元素,参数buffer_size建议设为样本数量,过大会浪费内存空间,过小会导致打乱不充分。

例如我们将batch-size设为2,那么每次iterator都会得到2个数据

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6])
dataset = dataset.shuffle(2)
for d in dataset:
    print(d)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)

2.4 repeat

repeat:repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6])
dataset = dataset.repeat(2)
dataset = dataset.shuffle(4)
# dataset = dataset.batch(2)
for d in dataset:
    print(d)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)

三、使用 tf.data 的并行化策略提高训练流程效率:

tf.data 的数据集对象为我们提供了 Dataset.prefetch() 方法,使得我们可以让数据集对象 Dataset 在训练时预取出若干个元素,使得在 GPU 训练的同时 CPU 可以准备数据,从而提升训练流程的效率。

mnist_dataset = mnist_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

在map()函数中,还有个很重要的参数num_parallel_calls,可以将数据加载与变换过程并行到多个CPU线程上。由于python语言本身的全局解释锁,想要实现真正的并行计算是非常困难的,所以这个参数实际上非常实用,通常的使用情景是网络训练时,GPU做模型运算的同时CPU加载数据。 还可以直接设置num_parallel_calls=tf.data.experimental.AUTOTUNE,这样会自动设置为最大的可用线程数,机器算力拉满。

mnist_dataset = mnist_dataset.map(map_func=rot90, num_parallel_calls=tf.data.experimental.AUTOTUNE)

四、总结:使用tf.data.Dataset.from_tensor_slices加载数据集

1: 准备要加载的数据(numpy,tensor)

2: 使用tf.data.Dataset.from_tensor_slices() 函数进行加载

3: 使用shuffle()打乱数据

4: 使用map()函数进行预处理

5: 使用batch()函数设置 batch size 值

6: 根据需要使用repeat()设置是否循环迭代数据集

Keras 支持使用 tf.data.Dataset 直接作为输入。当调用 tf.keras.Model 的 fit() 和 evaluate() 方法时,可以将参数中的输入数据 x 指定为一个元素格式为 (输入数据, 标签数据) 的 Dataset ,并忽略掉参数中的标签数据 y 。例如,对于上述的 MNIST 数据集,常规的 Keras 训练方式是:

model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)

使用 tf.data.Dataset 后,我们可以直接传入 Dataset :

model.fit(mnist_dataset, epochs=num_epochs)
  • 4
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值