Tensorflow通过tf.train.Coordinator和tf.train.QueueRunner来完成。tf.train.Coordinator的功能为协同管理多线程的功能,例如一起工作,一起停止。Coordinator提供了三个函数:should_stop、request_step、join。tf.train.QueueRunner注意用于启动多个线程来操作同一个队列,而线程的启动还是依靠tf.train.Coordinator完成。
should_stop:当should_stop=True时,停止所有现成。
request_step:让某个现成使用request_step时,should_stop将会被置为True,停止所有线程。
join:就像python threading中的join是一样的,等待所有线程退出。
至于它是如何使用的呢?就像”TFrecode格式读取“中的一样,需要首先声明一个tf.train.Coordinator()类,然后启动所有线程。
# 可以使用Tensorflow的多线程进行读取
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord
~~~~~~
# 停止所有线程
coord.request_stop()
coord.join(threads)
接下来以Google inception V3训练ImageNet时中的代码为例,大致流程为:
1、获取文件列表、2、获取文件队列、3、定义多个线程从文件队列中读取文件、4、队列出队、5、定义多线程对出队的文件进行处理、6、组合预处理之后的数据。
开始上代码:
1、获取文件列表
def data_files(self):
"""Returns a python list of all (sharded) data subset files.
Returns:
python list of all (sharded) data set files.
Raises:
ValueError: if there are not data_files matching the subset.
"""
tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % self.subset)
data_files = tf.gfile.Glob(tf_record_pattern)
if not data_files:
print('No files found for dataset %s/%s at %s' % (self.name,
self.subset,
FLAGS.data_dir))
self.download_message()
exit(-1)
return data_files
data_files = dataset.data_files()
data_files即为包含所有训练TFRecord的文件,文件名以train-?????-of-01024。
2、获取文件队列
# 文件队列的capacity=16
# shuffle=True,将文件顺序打乱,保证训练数据是随机的。
filename_queue = tf.train.string_input_producer(data_files,
shuffle=True,
capacity=16)
3、定义多个线程从文件队列中读取文件
# 使用RandomShuffleQueue,队列出队的顺序是随机的。
examples_queue = tf.RandomShuffleQueue(
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples,
dtypes=[tf.string])
# 定义多线程从文件队列中读取文件,文件队列会均匀的将文件分配给不同的线程
enqueue_ops = []
for _ in range(num_readers):
reader = dataset.reader()
_, value = reader.read(filename_queue)
# 队列的进队操作
enqueue_ops.append(examples_queue.enqueue([value]))
4、队列出队
在这里需要使用QueueRunner启动多个线程操作队列。这里面是启动len(enqueue_ops)个线程进行入队操作。
tf.train.queue_runner.add_queue_runner(
tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
# define the out queue
example_serialized = examples_queue.dequeue()
5、定义多线程对出队的文件进行处理
images_and_labels = []
for thread_id in range(num_preprocess_threads):
# Parse a serialized Example proto to extract the image and metadata.
image_buffer, label_index, bbox, _ = parse_example_proto(
example_serialized)
# 对图像进行预处理
image = image_preprocessing(image_buffer, bbox, train, thread_id)
images_and_labels.append([image, label_index])
6、组合预处理之后的数据。
images, label_index_batch = tf.train.batch_join(
images_and_labels,
batch_size=batch_size,
capacity=2 * num_preprocess_threads * batch_size)
经过组合后的数据,就可以做为神经网络的输入了,开启训练之旅。