range_input_producer多线程读取数据功能介绍

先上代码:

import tensorflow as tf
BATCH_SIZE = 6
NUM_EXPOCHES = 5
sess = tf.Session()
array = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
array = list(map(lambda line: line, array))
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
sess.run(tf.initialize_local_variables())
#注意num_epochs为局部变量(local variables),必须紧接着马上初始化局部变量。
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

sess.run(tf.initialize_all_variables())

try:
    index = 0
    while not coord.should_stop() and index<10:
        datalines = sess.run(inputs)
        index += 1
        print("step: %d, batch data: %s" % (index, str(datalines)))
except tf.errors.OutOfRangeError:
    print("Done traing:-------Epoch limit reached")
except KeyboardInterrupt:
    print("keyboard interrput detected, stop training")
finally:
    coord.request_stop()
coord.join(threads)
sess.close()
del sess

执行结果如下:

step: 1, batch data: [19 20 21 22 23 24]
step: 2, batch data: [ 7  8  9 10 11 12]
step: 3, batch data: [25 26 27 28 29 30]
step: 4, batch data: [13 14 15 16 17 18]
step: 5, batch data: [1 2 3 4 5 6]
step: 6, batch data: [13 14 15 16 17 18]
step: 7, batch data: [19 20 21 22 23 24]
step: 8, batch data: [ 7  8  9 10 11 12]
step: 9, batch data: [1 2 3 4 5 6]
step: 10, batch data: [25 26 27 28 29 30]

首先考虑代码:

i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,如果shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储,否则在Graph运行的时候,每个线程从队列取出元素是随机的,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch,slice的用法可以参考sclice用法。调用range_input_producer会生成一个输出整数的队列,同时与此队列对应的一个入队操作QueueRunner会自动加入当前图的 QUEUE_RUNNER集合中,即tf.GraphKeys.QUEUE_RUNNERS集合中,更详细的关于tensorflow的队列和多线程操作,可以参考队列和多线程。这里执行

threads = tf.train.start_queue_runners(sess=sess, coord=coord)

来启动多个线程来执行入队列操作,执行datalines = sess.run(inputs),根据出队列的索引获取指定切出位置的片段的数据值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值