第十章TensorFlow图像处理(二)Dataset API

10.6全新的数据读取方式——Dataset API

Dataset API 是Tensorflow中一种全新的数据读取方式,它可以用简单复用的方式构建复杂的Input Pipeline。(Pipeline意为流水线)
Dataset API主要用于数据读取,构建输入数据的pipeline等,Datset API可以很方便地以不同的数据格式处理大量数据及复杂的转换。

10.6.1 Dataset API架构

DataSet API中最重要的两个基础类:Dataset和Iterator。Dataset可以看作是相同类型的“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。

如何将这个dataset中的元素取出呢?方法是从Dataset中实例化一个Iterator(迭代器,作用是遍历所有的元素),然后对Iterator进行迭代。

下面我们对Dataset API中的两个基础类做进一步说明。

1.tf.data.Dataset
tf.data.Dataset表示一串元素(elements),其中每个元素包含了一个或者多个Tensor对象(张量,多维数组)。例如:在一个图片Pipeline中,一个元素可以是单个训练样本,该样本带有一个表示图像数据的tensors和一个label组成的数据对。有两种不同的方式创建一个dataset:
1)创建一个source,从一个或者多个tf.Tensor对象中构建一个dataset。
2)拥有Dataset对象以后,应用一个transformation,可以将它们转化为新的Dataset。
最常使用Dataset的方式是使用一个迭代器
2.tf.data.Iterator
它提供的主要方式是从一个dataset中抽取元素。

10.6.2 构建Dataset

构建Dataset方法很多,常用的几种方法如下。
1.tf.data.Dataset.from_tensor_slices()
利用tf.data.Dataset.from_tensor_slices()从一个或多个tf.Tensor对象中构建一个dataset,其tf.Tensor对象中包括数组、矩阵、字典、元组等,具体实例如下:

import tensorflow as tf
import numpy as np

arry1=np.array([1.0,2.0,3.0,4.0,5.0])
dataset=tf.data.Dataset.from_tensor_slices(arry1)#将array数组转化为tensor对象的数据集
iterator=dataset.make_one_shot_iterator()#生成实例,即创建迭代器
#从iterator里取出一个元素
one_element=iterator.get_next()#get_next()函数用于返回下一个tensor对象
with tf.Session() as sess:
    for i in range(len(arry1)):
        print(sess.run(one_element))

在这里插入图片描述
2.Dataset的转换(transformations)
Datasets支持任何结构,当使用Dataset.map(),Dataset.flat_map(),以及Dataseet.filter()进行转换时,它们会对每个元素应用一个函数,元素结构决定了函数的参数。

Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换或预处理。

以下是一些简单示例。

(1)使用map()

import tensorflow as tf
import numpy as np

arry1=np.array([1.0,2.0,3.0,4.0,5.0])
dataset=tf.data.Dataset.from_tensor_slices(arry1)#将array数组转化为tensor对象的数据集
iterator=dataset.make_one_shot_iterator()#生成实例,即创建迭代器
#从iterator里取出一个元素
one_element=iterator.get_next()#get_next()函数用于返回下一个tensor对象
with tf.Session() as sess:
    for i in range(len(arry1)):
        print(sess.run(one_element))

在这里插入图片描述
(2)使用flat_map()、filter()等

#使用‘Dataset.flat_map()’将每个文件转换为一个单独的嵌套数据集
#然后将它们的内容顺序连接成一个单一的“扁平”数据集
#跳过第一行(标题行)
#过滤以“#”开头的行

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

dataset = dataset.flat_map(
    lambda filename: (
        tf.data.TextLineDataset(filename)
        .skip(1)
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值