TensorFlow读取TFRecord数据:使用tf.data.TFRecordDataset读取和使用线程、队列读取的比较(接上一篇博客)

0 背景

im2txt代码中读取TFRecord数据使用的是多线程填充队列的方式,从tensorflow 1.4.0(大概是)之后推荐使用tf.data模块进行操作。使用tf.data进行数据读取比较好懂,因为上层的封装比较好,流程比较清楚;但是使用多线程和队列来进行数据获取,虽然整个流程大概是清楚了,但总觉得有什么地方有点迷糊。

1 tf.data模块读取TFRecord数据

看了一下最简单的示例,大概流程可以概括如下:

  1. 利用tf.data.TFRecordDataset()打开一个或多个TFRecord文件;
  2. 定义一个函数parser()指明对TFRecord文件中每一条数据的处理(包括具体数据的获取和预处理等),并使用map(parser)函数对每条数据应用该处理;
  3. 可选择性的对数据进行打乱、按批次获取数据及重复数据;
  4. 构造迭代器,使用get_next()迭代获取数据。

示例代码如下:

import tensorflow as tf
# 定义为每一条数据的处理,这里提取了图像id号和描述id序列
def parser(serialized):
    context, sequence = tf.parse_single_sequence_example(
        serialized, 
        context_features={
            'image/image_id': tf.FixedLenFeature([], dtype=tf.int64)
        },
        sequence_features={
            'image/caption_ids': tf.FixedLenSequenceFeature([], dtype=tf.int64)
        })
    image_id = context['image/image_id']
    caption = sequence['image/caption_ids']
    return image_id, caption
    
# 1、使用TFRecordDataset打开一个TFRecord文件,并对每一条数据应用上面定义的操作 
dataset = tf.data.TFRecordDataset('../data/mscoco/test-00007-of-00008')
dataset = dataset.map(parser)
# 可选择性的对数据进行打乱,buffer_size参数等效于tf.train.shuffle_batch的min_after_dequeue参数
# dataset = dataset.shuffle(buffer_size)
# 对数据按批量读取,执行该操作需保证数据的维度是一致的
# dataset = dataset.batch(batch_size)
# 对数据进行重复N份
# dataset = dataset.repeat(N)

# 2、创建迭代器
iterator = dataset.make_one_shot_iterator()
# 3、利用迭代器获取下一条数据
image, caption = iterator.get_next()

with tf.Session() as sess:
    i = 0
    while True:
        try:
            i += 1
            # 顺序获取数据,打印输出
            im, cap = sess.run([image, caption])        
            print(i, im, cap)
        except tf.errors.OutOfRangeError:
            break

代码会顺序读取TFRecord文件的每一条数据,并把其中的每一条数据中的图像id和描述单词的id输出。

但是,使用tf.data.TFRecordDataset()读取TFRecord数据,文件中的数据读完了就会抛出一个OutOfRangeError异常,如果设定了训练迭代步数,应该需要事先计算数据复制的份数,以保证训练正常进行?

2 使用线程和队列读取TFRecord数据

im2txt代码中这部分的流程可以总结如下:

  1. 使用tf.train.string_input_producer()创建一个文件名队列;
  2. 创建一个FIFOQueue或者RandomShuffleQueue作为数据队列,使用read()函数从文件名队列中读取一条数据,直接将其压入enqueue()前面创建的数据队列(实际上也可以不用压入队列,直接使用read()读取一条数据之后就可以使用tf.parse_single_sequence_example()对该条数据进行解析,这个地方压入队列是因为后面从数据队列中读数据使用了多个线程);
  3. 从数据队列中弹出dequeue()一条数据(默认是4个线程),然后使用tf.parse_single_sequence_example()对该条数据解析,主要是提取其中的图像数据和对应的描述数据(每个单词的id号);
  4. im2txt里面会对解析出来数据进行又一番处理,主要是从描述数据中提取输入序列、输出序列和序列掩码供训练时使用;然后会利用tf.train.batch_join()函数收集训练时每一个batch要使用的数据。

测试代码如下:

import tensorflow as tf

data_files = ['../data/mscoco/test-00007-of-00008']
reader = tf.TFRecordReader()
# 1、创建文件名队列
filename_queue = tf.train.string_input_producer(data_files, shuffle=False, capacity=16, name='filename_queue')

min_queue_examples = 2300*2
capacity = min_queue_examples+100*32
#values_queue = tf.RandomShuffleQueue(capacity=capacity, min_after_dequeue=min_queue_examples,
#                                    dtypes=[tf.string], name='random_input_queue')
# 为了验证数据的读取顺序是不是和前面tf.data.TFRecordDataset()一致,这里使用的是FIFOQueue进行测试
values_queue = tf.FIFOQueue(capacity=capacity, dtypes=[tf.string], name='random_input_queue')

# 2、从文件名队列中读取一条数据,并把该条数据入队到数据队列values_queue中
key, value = reader.read(filename_queue)
enqueue_ops = []
enqueue_ops.append(values_queue.enqueue([value]))  #入队操作
# 使用tf.train.QueueRunner创建线程运行队列的数据入队操作,这里相当于创建了1个线程
tf.train.add_queue_runner(tf.train.QueueRunner(values_queue, enqueue_ops))

images_and_captions = []
# 3、从数据队列values_queue中出队一条数据,并解析该条数据,这里只是提取了数据中的图像id和描述单词id
serialized = values_queue.dequeue()
context, sequence = tf.parse_single_sequence_example(
				serialized, 
				context_features={
					'image/image_id': tf.FixedLenFeature([], dtype=tf.int64)
				}, 
                sequence_features={
                	'image/caption_ids': tf.FixedLenSequenceFeature([], dtype=tf.int64)
                })
image_id = context['image/image_id']
caption = sequence['image/caption_ids']
images_and_captions = [image_id, caption]

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)    
    i = 0
    while True:
        try:
            i += 1
            data = sess.run(images_and_captions)
            print(i, data[0], data[1])
        except tf.errors.OutOfRangeError:
            break

最后输出的结果和前面使用tf.data.TFRecordDataset()读取TFRecord数据的结果一样,也会顺序读取TFRecord文件的每一条数据,并把其中的每一条数据中的图像id和描述单词的id输出(因为数据队列使用的FIFOQueue进行的定义)。如果数据队列使用RandomShuffleQueue定义的话,则会打乱输出数据的顺序。

不过不同的是:使用线程和队列获取数据的时候,读完数据之后会从头再次读取数据,因此不需要人为的控制数据的复制。

3 其他

这个时候,似乎又涉及到了另一个方面:TensorFlow训练过程数据的送入方式。(待查资料…

  1. 之前大多接触的,使用占位符placeholder实现数据的传入
  2. 使用线程和队列进行数据的输入,不需要占位符的定义(疑惑:提到说需要使用tf.train.Coordinator来协同启动的线程,并且由于使用了tf.train.QueueRunner()创建线程,因此需要明确调用tf.train.start_queue_runners来启动所有线程,否则没有线程执行入队操作,执行出队操作时线程会被挂起。但是,im2txt里面好像没有start_queue_runnners。)
  3. Dataset的使用。。。(还不是清楚使用Dataset怎样把数据送到训练模型中,可直接用于替换方式2?)
  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值