tensorflow中的队列和线程

一、队列

tensorflow中主要有FIFOQueue和RandomShuffleQueue两种队列,下面就详细介绍这两种队列的使用方法和应用场景。

1、FIFOQueue

FIFOQueue是先进先出队列,主要是针对一些序列样本。如:在使用循环神经网络的时候,需要处理语音、文字、视频等序列信息的时候,我们希望处理的时候能够按照顺序进行,这时候就需要使用FIFOQueue队列。

    #先入先出队列,初始化队列,设置队列大小5
    q = tf.FIFOQueue(5,"float")
    #入队操作
    init = q.enqueue_many(([1,2,3,4,5],))
    #定义出队操作
    x = q.dequeue()
    y = x + 1
    #将出队的元素加1,然后再加入到队列中
    q_in = q.enqueue([y])
    #创建会话
    with tf.Session() as sess:
        sess.run(init)
        #执行3次q_in操作
        for i in range(3):
            sess.run(q_in)
        #获取队列的长度
        que_len = sess.run(q.size())
        #将队列中的所有元素执行出队操作
        for i in range(que_len):
            print(sess.run(q.dequeue()))

2、RandomShuffleQueue

RandomShuffleQueue是随机队列,队列在执行出队操作的时候,是以随机的顺序进行的。随机队列一般应用在我们训练模型的时候,希望可以无序的获取样本来进行训练,如:在训练图像分类模型的时候,需要输入的样本是无序的,就可以利用多线程来读取样本,将样本放到随机队列中,然后再利用主线程每次从随机队列中获取一个batch进行模型的训练。

    #初始化一个随机队列,设置队列大小为10,最小长度为2
    q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes="float")
    #创建会话
    with tf.Session() as sess:
        #定义10次入队操作
        for i in range(10):
            sess.run(q.enqueue(i))
        #定义8次出队操作
        for i in range(8):
            print(sess.run(q.dequeue()))

注意:在使用随机队列的时候,我们设置了队列的容量为10,最小长度为2。当队列的长度已经等于队列的容量(10)再执行入队操作或队列的长度已经等于最小长度(2)再执行出队操作时,程序会发生阻断,即程序在执行,但是没有任何输出,如下图:

定义了10次出队操作,当队列出队8次之后,就被阻断了。我们可以通过设置会话在运行时的等待时间来解除阻断:

    #初始化一个随机队列,设置队列大小为10,最小长度为2
    q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes="float")
    #创建会话
    with tf.Session() as sess:
        #定义10次入队操作
        for i in range(10):
            sess.run(q.enqueue(i))
        #设置会话运行时等待时间,等待时长为5s
        run_options = tf.RunOptions(timeout_in_ms=5000)
        #定义10次出队操作
        for i in range(10):
            try:
                #当队列进入阻断之后,超时就抛出异常
                print(sess.run(q.dequeue(),options=run_options))
            except tf.errors.DeadlineExceededError:
                print("out of range")
                #退出循环
                break

当队列出队第9次的时候,进入阻断状态时,我们可以通过DeadlineExceededError来捕获阻断信息。

二、队列管理器

在训练模型的时候,我们需要将样本从硬盘读取到内存之后,才能进行训练。会话中可以运行多个线程,我们可以在队列管理器中创建一系列新的线程进行入队操作,主线程可以利用队列中的数据进行训练,而不需要等到所有的样本都读取完成之后才开始训练,即数据的读取和模型的训练是异步的,这样可以节省不少时间。

    #创建队列,设置队列的大小为1000
    q = tf.FIFOQueue(1000,"float")
    #定义计数器
    counter = tf.Variable(0.0)
    #给计数器加1
    increment_op = tf.assign_add(counter,tf.constant(1.0))
    #队列入队操作
    enque_op = q.enqueue([counter])
    #创建队列管理器
    qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enque_op]*1)
    #创建会话
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        #启动入队线程
        enqueue_threads = qr.create_threads(sess,start=True)
        #主线程
        for i in range(10):
            #定义出队操作
            print(sess.run(q.dequeue()))

程序结束的时候,还报了一个tensorflow.python.framework.errors_impl.CancelledError: Enqueue operation was cancelled的异常。那是因为主线程已经完成了,入队线程还在继续执行导致程序没法结束从而报错。由于计数器加1操作和入队操作不同步,可能会由于计数器还没来得及进行加1操作就再次被执行入队操作,从而导致多次入队同样的数字,也就是为什么出队的时候会出现同样的数字。

三、协调器

为了避免上述异常的发生,我们可以通过协调器来实现线程间的同步,来终止其他线程。

    #创建队列,设置队列的大小为1000
    q = tf.FIFOQueue(1000,"float")
    #定义计数器
    counter = tf.Variable(0.0)
    #给计数器加1
    increment_op = tf.assign_add(counter,tf.constant(1.0))
    #队列入队操作
    enque_op = q.enqueue([counter])
    #创建队列管理器
    qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enque_op]*1)
    #创建会话
    with tf.Session() as sess:
        #初始化变量
        sess.run(tf.global_variables_initializer())
        #创建一个线程协调器
        coord = tf.train.Coordinator()
        #启动入队线程
        enqueue_threads = qr.create_threads(sess,coord=coord,start=True)
        #主线程执行出队操作
        for i in range(10):
            print(sess.run(q.dequeue()))
        #通知其他线程关闭
        coord.request_stop()
        #等待其他线程结束,当其他线程都关闭之后,函数才返回结果
        coord.join(enqueue_threads)

通过上面的结果可以发现,程序能够正常的结束。但是,当关闭线程之后再执行出队操作,就会报OutOfRangeError的错误,代码如下

        coord.request_stop()
        for i in range(10):
            print(sess.run(q.dequeue()))
        coord.join(enqueue_threads)

对于这种情况,我们可以通过OutOfRangeError来捕获这个错误信息

        coord.request_stop()
        for i in range(10):
            try:
                print(sess.run(q.dequeue()))
            except tf.errors.OutOfRangeError:
                #退出循环
                break
        coord.join(enqueue_threads)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

修炼之路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值