TensorFlow Queue队列的使用

TensorFlow Queue队列的使用

TensorFlow队列可以采用多个线程来分别读取,训练模型,TensorFlow官方推荐使用队列来读取输入数据(由于队列比较麻烦,不够简单直白,现在推出了一个tf.data API),相比于采用第三方库加载预处理数据,然后把这些数据采用feed_dict方法feed进模型,队列会更加高效。不过队列真的有些复杂,繁琐,我研究了好久(可能是我太笨了)。。。。。。下面就介绍下我对队列的理解。

主要包括以下几个方面:

  • 产生文件名队列 tf.train.string_input_producer
  • 文件读取
  • 文件解码
  • 生成Batch

tf.train.string_input_producer函数

tf.train.string_input_producer函数用来产生文件名队列,它的输入为文件名列表或者单个的文件名,这些文件保存的是训练或评价时的数据。比如图像分类问题中,这些文件中就保存的是图像数据和label数据。文件格式有很多种,不同的格式对应不同的read和decode方式。如:TFRecord格式(TensorFlow的standard文件格式)。

文件读取

当文件传入tf.train.string_input_producer函数后,就形成了文件队列,然后就可以对文件进行读取。传入tf.train.string_input_producer函数的文件格式不同,对应不同的文件读取函数。TensorFlow提供了几个读取函数,如tf.TFRecordReader,tf.FixedLengthRecordReader,它们分别读取TFRecord格式的文件和每个记录都是定长度的二进制文件。这些读取函数会返回两个值,key和value。key相当于对读取的内容做一个标记,以便调试用,相当于说从哪个文件读取了哪个记录;value为读取的内容,不过为string value,这些值还不能使用,需要对读取到的string value值进行解码,这就需要解码函数。

文件解码

不同的文件格式对应不同的解码函数,如:tf.FixedLengthRecordReader读取的内容一般采用tf.decode_raw进行解码,它把一个string tensor转为uint8 tensor。把数据解码出来后就可以进行一些预处理操作了,如去均值,左右翻转等。

生成Batch

在对数据进行解码和预处理后就要把数据生产一个一个的Batch作为训练或评价时的batch data。TensorFlow提供了两个函数:tf.train.shuffle_batch、tf.train.shuffle_batch_join。tf.train.shuffle_batch函数接受预处理后的数据进来,可以设置batch size大小等,它把数据进行shuffle后生成batch size大小的模型输入数据。当采用多个线程进行文件读取时(调用多次文件读取函数),就需要采用tf.train.shuffle_batch_join函数,它和tf.train.shuffle_batch函数基本相同,只不过它接受的是一个存放多个线程读取的数据的列表。到这里就完成了模型输入数据的读取和生成。

最后放一张官方关于队列的图:

这里写图片描述

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值