tf.data官方教程 - - 基于TF-v2

这是本人关于tf.data的第二篇博文,第一篇基于TF-v1详细介绍了tf.data,但是v1和v2很多地方不兼容,所以替大家瞧瞧v2的tf.data模块有什么新奇之处。

TensorFlow版本:2.1.0

首先贴上TF v1版本的tf.data博文地址:《TensorFlow tf.data 导入数据(tf.data官方教程)


使用 tf.data 构建数据输入通道

tf.data API编写的数据输入通道简单、并且可重用度高。tf.data能够实现非常复杂的数据输入通道。例如:图像模型的数据输入管道可能会聚集来自分布式文件系统中文件的数据,对每个图像应用随机扰动,然后将随机选择的图像合并为一批进行训练。文本模型的数据输入管道可能涉及从原始文本数据中提取符号,将其转换为带有查找表的嵌入标识符,以及将不同长度的序列分批处理。tf.dataAPI使得处理大量数据,从不同数据格式读取数据以及执行复杂的转换成为可能。

tf.data API引入了tf.data.Dataset 这个抽象概念。它是一个元素组成的序列,每个元素可以由一个或多个部分组成。例如,图像的数据输入通道中,一个元素可以是由数据和标签组成的一个训练样本。

创建dataset的方法有两种:

  • 基于内存中的数据 或 硬盘中的一个或多个文件 建立Dataset
  • 通过对Dataset进行 transform 得到一个新的Dataset
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import pathlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

1. 基础知识

建立一个数据输入通道,一般需要从数据源开始。如果你的数据储存在内存中,你可以使用tf.data.Dataset.from_tensor()tf.data.Dataset.from_tensor_slices()创建Dataset。如果你的数据是TFRecord格式,你可以使用tf.data.TFRecordDataset()创建Dataset

一旦你有了一个Dataset对象,你可以通过调用它的方法对其进行变换产生一个新的 Dataset对象。

Dataset是一个Python可迭代对象。所以可以使用 for 循环来消耗它的元素:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset

<TensorSliceDataset shapes: (), types: tf.int32>

for elem in dataset:
  print(elem.numpy())

8
3
0
8
2
1

或者显式使用iter创建一个Python迭代器,并使用next来消耗其的元素:

it = iter(dataset)

print(next(it).numpy())

8

另外,也可以使用reduce()变换来消耗数据集的元素,根据所有元素产生单个结果。下面的示例说明如何使用reduce变换来计算整数数据集的总和。

print(dataset.reduce(0, lambda state, value: state + value).numpy())

22

1.1 Dataset 结构介绍

一个Dataset由多个相同结构的(嵌套)元素组成,每个元素又由多个可由tf.TypeSpec表示的部分组成(常见的有Tensor, SparseTensor, RaggedTensor, TensorArray, Dataset)。

利用Dataset.element_spec属性可以检查每个元素的组成部分的类型。该属性返回一个由tf.TypeSpec对象组成的嵌套结构,这个结构与Dataset中元素的结构是对应的。例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec

TensorSpec(shape=(10,), dtype=tf.float32, name=None)

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec

(TensorSpec(shape=(), dtype=tf.float32, name=None),
   \; TensorSpec(shape=(100,), dtype=tf.int32, name=None))

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec

(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
   \; (TensorSpec(shape=(), dtype=tf.float32, name=None),
   \;    \; TensorSpec(shape=(100,), dtype=tf.int32, name=None)))

# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec

SparseTensorSpec(TensorShape([3, 4]), tf.int32)

# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type

tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset 的变换支持任何结构的数据集。在使用 Dataset.map()Dataset.flat_map()Dataset.filter() 函数时(这些转换会对每个元素应用一个函数),元素结构决定了函数的参数:

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1

<TensorSliceDataset shapes: (10,), types: tf.int32>

for z in dataset1:
  print(z.numpy())

[6 7 1 1 5 6 7 8 7 6]
[8 3 3 7 9 3 8 4 8 4]
[2 3 6 9 4 2 1 8 1 6]
[6 7 1 9 6 2 4 7 9 1]

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2

<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3

<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>

for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))

shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

:为 Dataset 中的元素的各个组件命名通常会带来便利性(例如,元素的各个组件表示不同特征时)。除了元组之外,还可以使用 命名元组(collections.namedtuple) 或 字典 来表示 Dataset 的单个元素。

dataset = tf.data.Dataset.from_tensor_slices(
   {
   "a": tf.random.uniform([4]),
    "b": tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)})

dataset..element_spec

{‘a’: TensorSpec(shape=(), dtype=tf.float32, name=None), ‘b’: TensorSpec(shape=(100,), dtype=tf.int32, name=None)}

2. 读取输入数据

2.1 读取Numpy数组

See Loading NumPy arrays for more examples.

如果您的数据存储在内存中,则创建 Dataset 的最简单方法是使用Dataset.from_tensor_slices()创建dataset。

train, test = tf.keras.datasets.fashion_mnist.load_data() # out is np array

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels)) # auto convert np array to constant tensor
dataset

<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

注意:上面的代码段会将 features 和 labels 数组作为 tf.constant() 嵌入 TensorFlow 图中。这非常适合小型数据集,但会浪费内存,因为这会多次复制数组的内容,并可能会达到 tf.GraphDef 协议缓冲区的 2GB 上限。

2.2 读取Python生成器中的数据

另一个常见的数据源是Python生成器。

注意:虽然使用Python生成器很简单,但这种方法的移植性、可扩展性较差。它必须与生成器运行在同一个Python进程中,并且它仍然受Python GIL的制约。

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)

0
1
2
3
4

Dataset.from_generator可以将生成器转化为tf.data.Dataset.from_generator函数将可调用对象作为输入,从而在到达生成器末尾时可重新启动生成器。它带有一个可选args参数,利用该参数可向可调用对象传递传递参数。

output_types参数是必需的,因为tf.data会在后台构建一个tf.Graph(图的边界需要tf.type)。

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())

[0   \,   \, 1   \,   \, 2   \,   \, 3   \,   \, 4   \,   \, 5   \,   \, 6   \,   \, 7   \,   \, 8   \,   \, 9   \, ]
[10   \, 11   \, 12   \, 13   \, 14   \, 15   \, 16   \, 17   \, 18   \, 19]
[20   \, 21   \, 22   \, 23   \, 24   \, 0   \, 1   \, 2   \, 3   \, 4]
[ 5   \, 6   \, 7   \, 8   \, 9   \, 10   \, 11   \, 12   \, 13   \, 14]
[15   \, 16   \, 17   \, 18   \, 19   \, 20   \, 21   \, 22   \,

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
05-26
以下是对代码的改进建议: 1. 在代码开头添加注释,简要说明代码功能和使用方法。 2. 将导入模块的语句放在代码开头。 3. 将模型保存路径和评估时间间隔定义为常量,并使用有意义的变量名。 4. 将计算正确率和加载模型的过程封装为函数。 5. 在主函数中调用评估函数。 改进后的代码如下: ``` # 该代码实现了使用已训练好的模型对 MNIST 数据集进行评估 import time import tensorflow.compat.v1 as tf from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train # 定义常量 MODEL_SAVE_PATH = 'model/' EVAL_INTERVAL_SECS = 10 def evaluate(mnist): """ 计算模型在验证集上的正确率 """ with tf.Graph().as_default() as g: # 定义输入和输出格式 x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') # 直接调用封装好的函数计算前向传播结果 y = mnist_inference.inference(x, None) # 计算正确率 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 加载模型 variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY) variables_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variables_to_restore) # 在验证集上计算正确率 with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict={x: mnist.validation.images, y_: mnist.validation.labels}) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') def main(argv=None): # 读取数据集 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 每隔一定时间评估模型在验证集上的正确率 while True: evaluate(mnist) time.sleep(EVAL_INTERVAL_SECS) if __name__ == '__main__': tf.app.run() ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值