tensorflow high level API---import data

一、基本机制

其实就是tf.data接口可以更好的处理大规模的数据和各种数据类型,还有处理复杂的转换。

(1)tf.data.Datasets代表了一个元素的序列,着每一个元素包含了一个或者多个张量实体。有两种创建数据集的方法:第一种(创造一个源)是通过Dataset.from_tensor_slices来构建一个数据集从一个or多个张量实体;第二种是应用一种转换,例如Dataset.batch()构建一个数据集从一个或者多个tf.data.Dataset实体。

(2)td.data.Iterator:提供了主要的一种从数据集上取得元素的方法。Iterator.get_next()可以获取数据集的下一个元素。

第一你需要定义一个源,例如,你可以在内存的张量上构建数据集,使用tf.data.Dataset.from_tensors()或者tf.data.Dataset.from_tensor_slices().当然如果在硬盘上你有以TFRecord形式存储的数据,你可以建立tf.data.TFRecordDataset。

一旦你有了dataset实体,你可以通过tf.data.Dataset实体的一些方法来转化成一个新的数据集。你可以通过Dataset.map()将每个元素进行转化,也可以通过Dataset.batch()将多个元素进行转换。

最常见的消耗数据的方法就是在数据集上建立一个迭代器,可以通过Dataset.make_one_shot_iterator()来每次获取数据集的一个元素。tf.data.Iterator提供了两个操作,第一个是Iterator.initializer确保你初始化或者再次初始化你的迭代状态,第二个是Iterator.get_next()获取下一个数据。

1.1 数据结构

一个数据集拥有很多个结构相同的元素,每个元素包含一个或者多个张量,称之为组件compoents.每一个组件有个tf.Dtype的属性可以表示元素的属性,tf.TensorShape表明了每一个静态元素的形状。Dataset.output_types和Dataset.output_shapes表明了每个数据集元素的组件的类型和结构。

案例:

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset1=tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))
print(dataset1.output_shapes)
print(dataset1.output_types)

dataset2=tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([4]),
     tf.random_uniform([4,100],maxval=100,dtype=tf.int32))
)
print(dataset2.output_types)
print(dataset2.output_shapes)

dataset3=tf.data.Dataset.zip((dataset1,dataset2))
print(dataset3.output_types)
print(dataset3.output_shapes)

结果:

(10,)
<dtype: 'float32'>
(tf.float32, tf.int32)
(TensorShape([]), TensorShape([Dimension(100)]))
(tf.float32, (tf.float32, tf.int32))
(TensorShape([Dimension(10)]), (TensorShape([]), TensorShape([Dimension(100)])))

每个元素的组件起个名字非常的方便,代码:

dataset2=tf.data.Dataset.from_tensor_slices(
    {'a':tf.random_uniform([4]),
     'b':tf.random_uniform([4,100],maxval=100,dtype=tf.int32)}
)
print(dataset2.output_types)
print(dataset2.output_shapes)

结果:

(10,)
<dtype: 'float32'>
{'a': tf.float32, 'b': tf.int32}
{'a': TensorShape([]), 'b': TensorShape([Dimension(100)])}

代码:你可以通过下面的方法对数据集的数据进行转变。

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

1.2 创造迭代器

迭代器的类型:

one-shot、initializable、reinitializable、feedable

(1)one-shot最简单,不需要特殊的初始化操作,仅仅在数据集上迭代一次。这个支持处理几乎所有的现存的基于队列的输入流,但是不支持参数。Dataset.range()案例。只有这个可以被estimator使用。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset=tf.data.Dataset.range(100)
iterator=dataset.make_one_shot_iterator()
next_element=iterator.get_next()
with tf.Session() as sess:
    for i in range(100):
        value=sess.run(next_element)
        assert i==value

(2)initializer需要你运行一个iterator.initializer操作。当然你可以对你的数据集进行参数化,还可以使用多个tf.placeholder()可以在你初始化的时候传入。

max_value=tf.placeholder(tf.int64,shape=[])
dataset=tf.data.Dataset.range(max_value)
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer,feed_dict={max_value:10})
    for i in range(10):
        value=sess.run(next_element)
        assert i==value
    sess.run(iterator.initializer,feed_dict={max_value:100})
    for i in range(100):
        value=sess.run(next_element)
        assert i==value

(3)reinitializable:是一个可以从多个不同的数据集实体初始化的迭代器。例如你可能有一个训练数据的输入流,你会添加上一些扰乱从而来提高模型的泛化能力。然后你可能还有个交叉验证集的输入来在未修改的数据上进行评估结果。这些输入线使用的是不同的数据集,但是每一个组件是存在相同的类型和可兼容的shape的。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

training_dataset=tf.data.Dataset.range(100).map(
    lambda x:x+tf.random_uniform([],-10,10,tf.int64)
)
validation_dataset=tf.data.Dataset.range(50)

#这里的reinitializable迭代器可以从训练的数据集,当然也可以从交叉验证的数据集上获取类型和shape,兼容性
iterator=
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值