tensorflow数据读取——文件名队列和内存队列

详细信息请参考:https://zhuanlan.zhihu.com/p/27238630

以从文件中读入图像数据为例。


一、tensorflow读取机制图解

为了提高GPU/CPU的对数据的运算效率,引入“内存”的概念,我们把“读入数据到内存队列”和“GPU/CPU计算数据”分别放入两个线程中,其中”读入数据到内存队列”的线程的图示如下:


为了方便管理,还需要在“内存队列”前加入“文件名队列”:


然后有没有发现,红色框里流程特别像工厂流水线上的女工在勤劳工作,这就是“文件读取管线”的概念了。程序运行后

首先,把ABC依次放入“文件名队列并在之后标注队列结束;

然后,“内存队列”获得“文件名队列”中的ABC(坑:获得的顺序也有可能是CBA或BCA,代码部分会给出解释);

最后,系统检测到“结束”就可以结束程序了。

以上就是tensorflow中读取数据的基本机制。


二、代码详解

所谓“机制”不过是为了简化理解一个过程的概念而已。下面通过代码来看看每句命令对应的状态。注:以读入一张“A.jpg”的3*3*3的图为例(如下图,虽然看起来是灰度图,但实际是3通道的)。


1.image_name = ['./A.jpg']

这一句好理解,获取文件名。

2.filename_queue = tf.train.string_input_producer(image_name,shuffle=False)

tf.train.string_input_producer()表示创建“文件名队列”,注意这里仅仅只是创建哦!创建后,整个系统还是处于“停滞状态”,文件名并没有被加入到队列中(如下图所示)此时如果我们开始计算,因为内存队列中什么也没有,计算单元就会一直等待,导致整个系统被阻塞。也就说女工们已经就位,第一道工序还没开始,大家就都没活干,得等着。

填坑:然后注意到参数shuffle = False,意思是要从“内存队列”中顺序获得“文件名队列”得到A、B、C,如果是shuffle=True(默认),那么就会乱序获取到“内存队列”,结果变为CBA或BCA等。当然我们只读入一张图“A.jpg”的话,shuffle为False还是True都无所谓。


3 image_reader = tf.WholeFileReader()
4._,image_file = image_reader.read(filename_queue)

实际上在tensorflow中,内存队列不需要我们自己建立,我们只需要使用reader对象从文件名队列中读取数据就可以了,这里读取后的数据保存在 image_file 中。


5.image = tf.image.decode_jpeg(image_file,channels=3)

对读取的图片解码成jpg的格式。


6. coord = tf.train.Coordinator() #协同启动的线程
7. threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动线程运行队列
8. print(sess.run(image))
9. coord.request_stop() #停止所有的线程
10.coord.join(threads)

刚才说过,队列只被创建了,要打破僵局,需要使用tf.train.start_queue_runners(),才能启动填充队列,这时系统不再“停滞”,整个系统才能跑起来。该函数一般搭配Coordinator一起使用,这是负责在收到任何关闭信号的时候,让所有的线程都知道。本人一开始也是没有写这一句,导致程序一直停滞。

以下是完整的代码和显示结果:

import tensorflow as tf

sess = tf.Session()

image_name = ['./A.jpg']
filename_queue = tf.train.string_input_producer(image_name)
image_reader = tf.WholeFileReader()
_,image_file = image_reader.read(filename_queue)
image = tf.image.decode_jpeg(image_file,channels=3)

coord = tf.train.Coordinator() #协同启动的线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动线程运行队列
print(sess.run(image))
coord.request_stop() #停止所有的线程
coord.join(threads)
结果显示:

以上每一块代表图像的一行信息,所以显示共有三个分块,每一块包含所有列的信息。




阅读更多
个人分类: tensorflow
下一篇Win + tensorflow第三方库安装——numpy、matplotlib...
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭