问题描述
在TensorFlow的Session作用域中通过eval打印Tensor时,发现eval一直阻塞,始终不返回(never ends)。
复现场景
#encoding=utf-8
import tensorflow as tf
def read_tfrecord():
    """Build the ops that read one example from the dog-image TFRecord file.

    The filename queue is created inside this function, so it is only added
    to the graph when the function is called.

    Returns:
        A pair ``(image, label)`` of tensors: the ``image`` feature decoded
        as ``tf.uint8`` and the ``label`` feature as a string tensor.
    """
    record_filename = "./data/dog_image.tfrecord"
    matched_filenames = tf.train.match_filenames_once(record_filename)
    filename_queue = tf.train.string_input_producer(matched_filenames)

    reader = tf.TFRecordReader()
    # read() returns (key, serialized example); the key is not used here.
    _, serialized_example = reader.read(filename_queue)

    # Both features are stored as scalar byte strings.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.string),
        'image': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    image = tf.decode_raw(parsed['image'], tf.uint8)
    label = tf.cast(parsed['label'], tf.string)
    return image, label
with tf.Session() as sess:
    # Reproduction script: the statement order below is intentional.
    # Local variables are initialised and the queue runners are started
    # first; read_tfrecord() is only called afterwards.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    queue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    image_tensor, label_tensor = read_tfrecord()

    print("before eval")
    print(image_tensor.eval()) #A
    print(label_tensor.eval()) #B
    print("end eval") #C

    # Ask the queue-runner threads to stop and wait for them to finish.
    coord.request_stop()
    coord.join(queue_threads)
执行此段代码会发现代码A和B处不会出现任何输出,而且代码C处也未输出,说明代码阻塞在eval函数执行处。
解决方法
把上述代码中的read_tfrecord函数改为如下形式即可解决。
record_filename = "./data/dog_image.tfrecord"
# The filename queue is created at module scope instead of inside
# read_tfrecord(), which then only consumes it.
matched_filenames = tf.train.match_filenames_once(record_filename)
tf_record_filename_queue = tf.train.string_input_producer(matched_filenames)
def read_tfrecord():
    """Build the ops that read one example from the module-level filename queue.

    Unlike a version that creates its own queue, this function reads from
    ``tf_record_filename_queue``, which is defined at module scope.

    Returns:
        A pair ``(image, label)`` of tensors: the ``image`` feature decoded
        as ``tf.uint8`` and the ``label`` feature as a string tensor.
    """
    reader = tf.TFRecordReader()
    # read() returns (key, serialized example); the key is not used here.
    _, serialized_example = reader.read(tf_record_filename_queue)

    # Both features are stored as scalar byte strings.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.string),
        'image': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    image = tf.decode_raw(parsed['image'], tf.uint8)
    label = tf.cast(parsed['label'], tf.string)
    return image, label
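具体原因未知,希望有知道的小伙伴告诉一下。(补充说明:一个可能的解释是,tf.train.start_queue_runners只会启动调用它时已经存在于计算图中的QueueRunner,tf.local_variables_initializer也只包含调用时已创建的局部变量。原代码先执行了这两步,之后才调用read_tfrecord()创建文件名队列和match_filenames_once的局部变量,所以没有线程向队列填充数据,eval只能一直等待。把队列的创建放到函数外部后,队列在进入Session之前就已在图中,因此可以被正常初始化和启动。)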
另一个原因
You must call tf.train.start_queue_runners(sess) before you call train_data.eval() or train_labels.eval().
This is a(n unfortunate) consequence of how TensorFlow input pipelines are implemented: the tf.train.string_input_producer(), tf.train.shuffle_batch(), and tf.train.batch() functions internally create queues that buffer records between different stages in the input pipeline. The tf.train.start_queue_runners() call tells TensorFlow to start fetching records into these buffers; without calling it the buffers remain empty and eval() hangs indefinitely.