tf.data输入的简单介绍

以前常用的tensorflow输入处理的方法有二种:
1、feed_dict
2、Queues
前者使用灵活但效率较低,后者使用复杂但效率较高。

tf.data API正好解决了两者的缺点,而且使用方便,速度比以上两种都快。

构造dataset

tf.data.Dataset.from_tensors((features, labels))  #从一个tensor tuple创建一个单元素的dataset
tf.data.Dataset.from_tensor_slices((features, labels))  #创建一个包含多个元素的dataset
tf.data.TextLineDataset(filenames)  #读取一个文件名列表,将每个文件中的每一行作为一个元素
tf.data.TFRecordDataset(filenames) #读取硬盘中的TFRecord格式文件,构造dataset

dataset.map(lambda x: tf.decode_jpeg(x))  #用map对dataset中的每个元素进行处理
dataset.repeat(NUM_EPOCHS)  #将dataset重复一定数目的次数用于多个epoch的训练
dataset.batch(BATCH_SIZE) #将原来的dataset中的元素按照某个数量叠在一起,生成mini batch

以上是常用函数,TensorFlow 1.4 版本中还允许用户通过Python的生成器构造dataset,如下:

def generator():
  while True:
    yield ...

dataset = tf.data.Dataset.from_generator(generator, tf.int32)

构造dataset的一个常用代码段:

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.map(lambda x: )
dataset = dataset.shuffle(10000)
dataset = dataset.repeat(100)  #100个epoch
dataset = dataset.batch(128)   #batch_size为128

迭代器

数据集定义好后,相应的采用迭代器访问数据,由简单到复杂,主要有以下几种迭代器
1、one-shot
2、initializable
3、reinitializable
4、feedable

one-shot

代码实例如下:

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset.batch(128)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(epoch):
    try:
         sess.run(next_element)
     except tf.errors.OutOfRangeError:          #数据遍历一遍结束
         ...

initializable

使用之前显式的通过调用iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

#Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

#Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

reinitializable

reinitializable iterator 可以被不同的 dataset 对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象,很常用

training_dataset = tf.data.Dataset.range(100).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

iterator = tf.Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)
next_element = iterator.get_next()

training_init = iterator.make_initializer(training_dataset)
validation_init = iterator.make_initializer(validation_dataset)

for _ in range(epoch):
    sess.run(training_init_op)         #train初始化
        try:
            sess.run(next_element)
        except tf.errors.OutOfRangeError: 
            ...
    sess.run(validation_init_op)    #test初始化
        try:
            sess.run(next_element)
        except tf.errors.OutOfRangeError: 
            ...

feedable

feedable iterator 可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。

training_dataset = tf.data.Dataset.range(100).map( lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

while True:
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值