tf.train.Coordinator

TensorFlow里与Queue有关的概念和用法。

  • Queue 是TensorFlow队列和缓存机制的实现。
  • QueueRunner 是TensorFlow中对操作Queue的线程的封装。
  • Coordinator 是TensorFlow中用来协调线程预先的工具。

Queue

根据实现的方式不同,可以分成一下几类:

  • tf.FIFOQueue 按入列顺序出列的队列
  • tf.RandomShuffleQueue 随机顺序出列的队列
  • tf.PaddingFIFOQueue 以固定长度批量出列的队列
  • tf.PriorityQueue 带优先级出列的队列
  • ...

这些类型的Queue出列自身的性质不太一样,创建、使用的方法基本是相同的。

# 创建函数的参数
tf.FIFOQueue(capacity, dtypes, shapes=None, names=None ...)

Queue主要包含入列(enqueue)和出列(dequeue)两个操作。enqueue操作返回计算图中的一个Operation节点,dequeue操作返回一个Tensor值。Tensor在创建时同样只是一个定义或称为声明,需要放在Session中运行才能获得真正的数值。

import tensorflow as tf
tf.InteractiveSession()
q = tf.FIFOQueue(2, 'float')
init = q.enqueue_many(([0,0],))
x = q.dequeue()
y = x + 1
q_inc = q.enqueue([y])
init.run()
q_inc.run()
q_inc.run()
x.eval()
x.eval()
x.eval()

在这里插入图片描述
如果一次性入列超过Queue Size的数据,enqueue操作会卡住,直到有数据(被其他线程)从队列取出。对一个已经取空的队列使用dequeue操作也会卡住,直到有新的数据(从其他线程)写入。

QueueRunner

TensorFlow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于计算。因此通常会使用多个线程读取数据,然后使用一个线程消费数据。QueueRunner就是用来管理这些读写队列的线程。
QueueRunner需要与Queue一起使用,但并不一定必须使用Coordinator。

import tensorflow as tf

q = tf.FIFOQueue(10, 'float')
counter = tf.Variable(0.0) #计数器
# 给计数器加一
increment_op = tf.assign_add(counter, 1.0)
#将计数器加入队列
enqueue_op = q.enqueue(counter)
# 创建QueueRunner
# 用多个线程像队列添加数据
# 这里实际创建了4个线程,两个增加计数,两个执行入队
qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)
# 主线程
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
#启动入队线程
qr.create_threads(sess, start=True)
for i in range(20):
    print(sess.run(q.dequeue()))
# 输出
0.0
0.0
0.0
2.0
2.0
3.0
5.0
5.0
6.0
7.0
7.0
8.0
10.0
10.0
17.0
26.0
40.0
49.0
57.0
72.0

增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费完后,入队的进程又会开始执行。最终主线程消费完20个数据后停止,但其他线程进行运行,程序不会结束。

Coordinator

Coordinator是用来保存线程组运行状态的协调器对象,它和TensorFlow的Queue没有必然关系,是可以单独和Python线程使用的。

Coordinator和Python线程使用

import tensorflow as tf
import threading, time
# 子线程函数
def loop(coord, id):
    t = 0
    while not coord.should_stop():
        print(id)
        time.sleep(1)
        t += 1
        # 只有1号线程调用request_stop方法
        if (t >= 2 and id == 1):
            coord.request_stop()
# 主线程
coord = tf.train.Coordinator()
# 使用Python API创建10个线程
threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)]
# 启动所有线程,并等待线程结束
for t in threads: 
    t.start()
coord.join(threads)
# 输出
0
1
2
3
4
5
6
7
8
9
0
1
3
8
4
2
9
5
6
7
0
3

将这个程序运行起来,会发现所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而整个程序结束。只要有任何一个线程调用了Coordinator的request_stop方法,所有的线程都可以通过should_stop方法感知并停止当前线程。

将QueueRunner和Coordinator一起使用,是为了封装这个判断操作,从而使任何一个线程出现异常时,能够正常结束整个程序,同时主线程也可以直接调用request_stop方法来停止所有子线程的执行。

Coordinator和TensorFlow使用

TensorFlow的Session对象支持多线程,可以在同一个会话(Session)中创建多个线程,并行执行。在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候,队列必须能被正确的关闭。
tf.train.Coordinator是tensorflow的一个多线程管理器,Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象。
在这里插入图片描述

Queue的两种经典模式

显示的创建QueueRunner,调用create_threads方法启动线程

import tensorflow as tf
import numpy as np
# 1000个4维输入向量,每个数据取值为1-10之间的随机数
data = 10 * np.random.randn(1000, 4) + 1
# 1000个随机的目标值,值为0或1
target = np.random.randint(0, 2, size=1000)
# 创建Queue,队列中每一项包含一个输入数据和相应的目标值
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32], shapes=[[4], []])
# 批量入列数据(这是一个Operation)
enqueue_op = queue.enqueue_many([data, target])
# 出列数据(这是一个Tensor定义)
data_sample, label_sample = queue.dequeue()
# 创建包含4个线程的QueueRunner
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
with tf.Session() as sess:
    # 创建Coordinator
    coord = tf.train.Coordinator()
    # 启动QueueRunner管理的线程
    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
    # 主线程,消费100个数据
    for step in range(100):
        if coord.should_stop():
            break
        data_batch, label_batch = sess.run([data_sample, label_sample])
    # 主线程计算完成,停止所有采集数据的进程
    coord.request_stop()
    coord.join(enqueue_threads)

使用全局的start_queue_runners方法启动线程

import tensorflow as tf

# 同时打开多个文件,显示创建Queue,同时隐含了QueueRunner的创建
filename_queue = tf.train.string_input_producer(['data1.csv', 'data2.csv'])
reader = tf.TestLineReader(skip_header_lines=1)
#Tensorflow的Reader对象可以直接接受一个Queue作为输入
key, value = render.read(filename_queue)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # 启动计算图中所有的队列线程
    threads = tf.train.start_queue_runners(coord=coord)
    # 主线程,消费100个线程
    for _ in range(100):
        features, labels = sess.run([key, value])
    # 主线程计算完成,停止所有采集数据的进程
    coord.request_stop()
    coord.join(threads)

参考资料
理解TensorFlow的Queue
多线程输入数据处理框架
TensorFlow之线程管理器tf.train.Coordinator
tf.train.Coordinator

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值