【深度学习笔记(十七)】之Tensorflow2中构造数据集对象,预处理数据以及获取使用数据集

本文章由公号【开发小鸽】发布!欢迎关注!!!


老规矩–妹妹镇楼:

一. 创建数据集对象

(一)tf.data.Dataset的组成

       tensorflow2中通过tf.data.Dataset类对数据集进行了封装。tf.data.Dataset由一系列的可迭代访问的元素组成,每个元素包含一个或多个张量。如图像数据集中,每个元素可以是形状为 长 x 宽 x 通道数 的图片张量,也可以是由图片张量和图片标签张量组成的元组。

(二)创建tf.data.Dataset对象

       当数据量较小时,可以通过tf.data.Dataset.from_tensor_slices()创建tf.data.Dataset对象,传入一个张量或者多个第0维大小相同的张量,即可按张量的第0维展开构建数据集,数据集的元素数量为张量第0维的大小。当提供多个张量作为输入时,张量的第0维大小必须相同,且必须将多个张量作为元组拼接并作为输入。

如:

import tensorflow as tf
import numpy as np
.
X = tf.constant([2013, 2014, 2015, 2016, 2017])
Y = tf.constant([12000, 14000, 15000, 16500, 17500])

dataset = tf.data.Dataset.from_tensor_slices((X, Y))

for x, y in dataset:
	print(x.numpy(), y.numpy())

(三)展示mnist_dataset数据集的图片和标签


import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
(train_data, train_label), ( test_data, test_label ) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255., axis=-1)
mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))

#mnist_dataset = mnist_dataset.map()
for image, label in mnist_dataset:
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :, 0])
    plt.show()

二. 数据集对象的预处理

(一) Dataset.map(fun)

       对数据集对象中的每个元素应用函数fun,得到一个新的数据集。

def rot90(image, label):
	image = tf.image.rot90(image)
	return image, label

mnist_dataset = mnist_dataset.map(rot90)

for image, label in mnist_dataset:
	plt.title(label.numpy())
	plt.imshow(image.numpy()[:, :, 0])
	plt.show()

(二) Dataset.shuffle(buffer_size)

       打乱数据集,设定一个buffer_size大小的缓冲区,取出前buffer_size个元素放到缓冲区中,再从缓冲区中随机采样,后续的数据填补空缺。当数据集的顺序分布较为随机,则缓冲区的大小可较小,否则需要大的缓冲区。

mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)

for images, labels in mnist_dataset:
	fig, axs = plt.subplots(1, 4)
	for i in range(4):
		axs[i].set_title(labels.numpy()[i])
		axs[i].imshow(images.numpy()[i, :, :, 0])
	plt.show()

(三) Dataset.batch(batch_size)

       将数据集的数据按batch_size分成多个批次,每batch_size个元素,使用tf.stack()在第0位合并为一个元素。

mnist_dataset = mnist_dataset.batch(4)

for images, labels in mnist_dataset: # image: [4, 28, 28, 1], labels: [4]
fig, axs = plt.subplots(1, 4)
for i in range(4):
	axs[i].set_title(labels.numpy()[i])
	axs[i].imshow(images.numpy()[i, :, :, 0])
plt.show()

三. 获取数据集的元素

(一) for循环

       tf.data.Dataset是一个可迭代的对象,可以使用for循环迭代获取数据。

(二) iter()

       使用iter()显示创建一个迭代器,并使用next()获取下一个元素:

dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
it = iter(dataset)
a_0, b_0, c_0, ... = next(it)
a_1, b_1, c_1, ... = next(it)

四. 使用数据集对象

       keras支持使用tf.data.Dataset作为输入,如fit()函数需要传入训练数据,标签数据,迭代次数以及每批数据的数量。

model.fit(x=train_data, y=train_label, epochs=num_epochs,
batch_size=batch_size)

       我们可以直接传入Dataset,在此之前通过Dataset.batch(batch_size)进行分批,因此,fit()函数中只需要传入Dataset以及迭代次数即可:

model.fit(mnist_dataset, epochs=num_epochs)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值