import tensorflow as tf
# Script: stream "train.csv" through a TF1 queue-based input pipeline exactly
# once, printing each shuffled batch and a running count of batches read.

path = "train.csv"
# One default per CSV column (12 columns). Each default fixes both the dtype
# that tf.decode_csv produces for that column and the value used when the
# field is empty.
record_defaults = [[0.0], [0.0], [0], [""], [""], [0.0], [0.0], [0.0], [""], [0.0], [""], [""]]
batch_size = 10

# num_epochs=1 limits the producer to a single pass over the file; once it is
# exhausted the queue closes and sess.run raises OutOfRangeError. Without it
# the file would be cycled forever and the loop below would never terminate.
filename_queue = tf.train.string_input_producer(num_epochs=1, string_tensor=[path])

# Read the file line by line, skipping the header row.
reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(filename_queue)
decoded = tf.decode_csv(value, record_defaults=record_defaults)

# NOTE(review): with shuffle_batch's default allow_smaller_final_batch=False,
# a trailing partial batch is presumably dropped - confirm that is intended.
batch_data = tf.train.shuffle_batch(
    decoded,
    batch_size=batch_size,
    capacity=batch_size * 20,
    min_after_dequeue=batch_size,
    num_threads=5,
)

# Number of batches successfully fetched.
count = 0
with tf.Session() as sess:
    # The local initializer is required as well: num_epochs makes the
    # producer keep its epoch counter in a local variable.
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        while not coord.should_stop():
            # Fetch first and count afterwards, so the final attempt that
            # raises OutOfRangeError is not counted as a batch.
            batch = sess.run([batch_data])
            count += 1
            print(batch)
            print(count)
    except tf.errors.OutOfRangeError:
        print("training stop, input queue is empty")
    finally:
        # Ask the queue-runner threads to stop on every exit path.
        coord.request_stop()
        print("count= ", count)
    # Wait for the queue-runner threads to finish before the session closes.
    coord.join(threads)
本代码使用的数据集为 Titanic 数据集。
注意：在使用 tf.train.string_input_producer 读取数据时，如果不设置 num_epochs，线程就会不断地从文件中循环读取数据（从打印输出中可以明显看到）。这会导致程序在运行时陷入死循环，无法触发 OutOfRangeError。