原文:http://blog.csdn.net/lyg5623/article/details/69387917
先放关键代码:
- i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
- inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
main.py内容:
- import tensorflow as tf
- import codecs
-
- BATCH_SIZE = 6
- NUM_EXPOCHES = 5
-
-
- def input_producer():
- array = codecs.open("test.txt").readlines()
- array = map(lambda line: line.strip(), array)
- i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
- inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
- return inputs
-
-
- class Inputs(object):
- def __init__(self):
- self.inputs = input_producer()
-
-
- def main(*args, **kwargs):
- inputs = Inputs()
- init = tf.group(tf.initialize_all_variables(),
- tf.initialize_local_variables())
- sess = tf.Session()
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- sess.run(init)
- try:
- index = 0
- while not coord.should_stop() and index<10:
- datalines = sess.run(inputs.inputs)
- index += 1
- print("step: %d, batch data: %s" % (index, str(datalines)))
- except tf.errors.OutOfRangeError:
- print("Done traing:-------Epoch limit reached")
- except KeyboardInterrupt:
- print("keyboard interrput detected, stop training")
- finally:
- coord.request_stop()
- coord.join(threads)
- sess.close()
- del sess
-
- if __name__ == "__main__":
- main()
输出:
- step: 1, batch data: ['1' '2' '3' '4' '5' '6']
- step: 2, batch data: ['7' '8' '9' '10' '11' '12']
- step: 3, batch data: ['13' '14' '15' '16' '17' '18']
- step: 4, batch data: ['19' '20' '21' '22' '23' '24']
- step: 5, batch data: ['25' '26' '27' '28' '29' '30']
- Done traing:-------Epoch limit reached
如果range_input_producer去掉参数num_epochs=1,则输出:
- step: 1, batch data: ['1' '2' '3' '4' '5' '6']
- step: 2, batch data: ['7' '8' '9' '10' '11' '12']
- step: 3, batch data: ['13' '14' '15' '16' '17' '18']
- step: 4, batch data: ['19' '20' '21' '22' '23' '24']
- step: 5, batch data: ['25' '26' '27' '28' '29' '30']
- step: 6, batch data: ['1' '2' '3' '4' '5' '6']
- step: 7, batch data: ['7' '8' '9' '10' '11' '12']
- step: 8, batch data: ['13' '14' '15' '16' '17' '18']
- step: 9, batch data: ['19' '20' '21' '22' '23' '24']
- step: 10, batch data: ['25' '26' '27' '28' '29' '30']
有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:
- InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
- [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。