目录
多线程输入数据处理框架
为了避免图像预处理成为神经网络模型训练效率的瓶颈,TF提供了一套多线程处理输入数据的框架。经典的输入数据处理流程为如图所示。
队列与多线程
在tensorflow中,队列不仅是一种数据结构,更提供了多线程机制,队列也是TF中多线程输入数据处理框架的基础。比如多个线程可以同时向一个队列中写元素,或者同时读取一个队列中的元素。
队列和变量都是计算图上有状态的节点。通过赋值可以修改变量的取值;通过Enqueue(入队)、EnqueueMany(队列初始化)、Dequeue(出队)来修改队列状态。以下程序展示了如何使用这些函数来操作一个队列。
import tensorflow as tf
#创建一个先进先出队列,指定队列中最多可以保存两个元素,并指定类型为整数。
q = tf.FIFOQueue(2, 'int32')
#使用enqueue_many函数来初始化队列元素,在使用队列之前要明确调用这个初始化过程
init = q.enqueue_many(([0, 10],))
#使用Dequeue函数将队列中第一个元素出队列,这个元素将被存在变量x
x = q.dequeue()
y = x + 1
#重新加入队列
q_inc = q.enqueue([y])
with tf.Session() as sess:
#运行初始化队列的操作
sess.run(init)
for _ in range(5):
v, _ = sess.run([x, q_inc])
print(v)
# 0,10,1,11,2
在TF中提供了FIFOQueue和RandomShuffleQueue两种队列。在上面的程序中展示了FIFOQueue类的先进先出队列。而RandomShuffleQueue会将队列中的元素打乱,每次出队列操作得到的是从当前队列中随机选择的一个元素。
利用队列,tensorflow可以实现多线程数据处理。TF提供了tf.Coordinator和tf.QueueRunner两个类来完成多线程协同的功能。
tf.Coordinator主要用于协同多个线程一起停止,提供了should_stop,request_stop和join三个函数。
- 在启动线程时,要先声明一个tf.Coordinator类,并将这个类传入创建的每一个线程中。
- 启动的线程一直查询程should_stop状态,只有当其为True时则退出。
- 每一个启动的进程可以通过调用request_stop函数来通知其他线程退出。
import tensorflow as tf
import numpy as np
import threading
import time
#在线程中运行的程序,这个程序每隔1s判断是否需要停止打印自己的id
def MyLoop(coord, worker_id):
#使用tf.Coordinator()类来判断当前进程是否需要停止
while not coord.should_stop():
if np.random.rand() < 0.1:
print('Stoping from id:%d' % worker_id)
# 调用request_stop来通知其他进程停止
coord.request_stop()
else:
print('Working on id:%d'% worker_id)
time.sleep(1)
#声明一个tf.train.Coordinator()类
coord = tf.train.Coordinator()
#创建五个线程
threads = [threading.Thread(target=MyLoop, args=(coord,i,))for i in range(5)]
#启动所有的线程
for t in