TensorFlow's official documentation describes tf.train.shuffle_batch() as follows.
The function's purpose is: "Creates batches by randomly shuffling tensors."
Note that this is a graph operation, so it has to be run inside sess.run().
The function adds the following to the current graph:
- A shuffling queue into which tensors from tensors are enqueued.
- A dequeue_many operation to create batches from the queue.
- A QueueRunner added to the QUEUE_RUNNER collection; it is this runner's threads that push the incoming tensors into the queue.
Putting the data in a queue has several benefits: it decouples the training and test data from the model, and (in my understanding) it also makes it easier to write distributed training. Note, however, that dequeuing can easily block when the queue runs dry.
In that case you should catch the timeout exception and force the threads to stop, as sketched below.
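A minimal sketch of that recovery path (the dummy queue and the 5-second deadline are invented for illustration): give the blocking sess.run() a deadline via tf.RunOptions and catch tf.errors.DeadlineExceededError:

import tensorflow as tf

queue = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
dequeued = queue.dequeue()  # blocks as long as the queue stays empty

coord = tf.train.Coordinator()
with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Give the dequeue a 5-second deadline instead of waiting forever.
    run_options = tf.RunOptions(timeout_in_ms=5000)
    try:
        value = sess.run(dequeued, options=run_options)
    except tf.errors.DeadlineExceededError:
        # Nothing was enqueued in time: force the threads to stop.
        coord.request_stop()
    coord.join(threads)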
Coordinator
The Coordinator class helps multiple threads work together and stop together. Its main methods are:
- should_stop(): returns True if the threads should stop.
- request_stop(<exception>): requests that the threads stop.
- join(<list of threads>): waits for the specified threads to terminate.
First create a Coordinator object, then create the threads that use it. These threads typically run in a loop until should_stop() returns True. Any thread can decide when the computation should stop: it only needs to call request_stop(), after which should_stop() returns True in the other threads as well, and they all stop.
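A minimal sketch of this pattern (the my_loop worker and the thread count are made up for illustration), using plain Python threads with a Coordinator:

import threading
import tensorflow as tf

def my_loop(coord):
    # Run until any thread requests a stop.
    step = 0
    while not coord.should_stop():
        step += 1
        if step >= 100:
            # This thread decides the computation is done;
            # should_stop() now returns True in every thread.
            coord.request_stop()

coord = tf.train.Coordinator()
threads = [threading.Thread(target=my_loop, args=(coord,)) for _ in range(4)]
for t in threads:
    t.start()
coord.join(threads)  # wait for all threads to terminate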
The complete example below wires all of this together, reading the Adult CSV data through a queue-based input pipeline:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
# The 15 columns of the Adult CSV data; the last one, income_bracket, is the label.
COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
           "marital_status", "occupation", "relationship", "race", "gender",
           "capital_gain", "capital_loss", "hours_per_week", "native_country",
           "income_bracket"]
def read_file_format(filename_queue):
    """
    Read the file data; each execution reads a single line.
    :param filename_queue: filename queue built by tf.train.string_input_producer
    :return: (features, label) tensors for one example
    """
    # Define the reader.
    reader = tf.TextLineReader()
    # Each execution of read reads a single line from the file.
    key, record_string = reader.read(filename_queue)
    # Define the decoder.
    # record_defaults: the defaults used to fill in missing fields,
    # one string default per column.
    record_defaults = [['']] * len(COLUMNS)
    cols = tf.decode_csv(record_string, record_defaults=record_defaults)
    # shuffle_batch expects tensors with known shapes, so stack the feature
    # columns into a single string tensor rather than zipping them with COLUMNS.
    features = tf.stack(cols[:-1])
    # The last column (income_bracket) serves as the label.
    label = cols[-1]
    return features, label
def input_pipeline(filenames, batch_size, num_epochs=None):
    """
    Input pipeline that reads the data through queues.
    :param filenames: list of CSV file names
    :param batch_size: number of examples per batch
    :param num_epochs: number of passes over the data (None means cycle forever)
    :return: (example_batch, label_batch)
    """
    # Creates a FIFO queue for holding the filenames until the reader needs them.
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
    single_example, single_label = read_file_format(filename_queue)
    # min_after_dequeue defines how big a buffer we will randomly sample
    # from -- bigger means better shuffling but slower start up and more
    # memory used.
    # capacity must be larger than min_after_dequeue and the amount larger
    # determines the maximum we will prefetch. Recommendation:
    # min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    # Creates batches of batch_size examples and batch_size labels by
    # randomly shuffling the enqueued tensors.
    example_batch, label_batch = tf.train.shuffle_batch(
        [single_example, single_label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch
train_file_name = "./data/adult.data"
test_file_name = "./data/adult.test"
example_batch, label_batch = input_pipeline([train_file_name, test_file_name], batch_size=10)

with tf.Session() as sess:
    # Start populating the filename queue. The Coordinator helps the queue
    # runner threads work together and stop together.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    train_steps = 10
    # Retrieve one batch per step:
    try:
        while not coord.should_stop():  # True once a stop has been requested
            example, label = sess.run([example_batch, label_batch])
            print(example)
            train_steps -= 1
            print(train_steps)
            if train_steps <= 0:
                coord.request_stop()  # ask the threads to stop
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop...
        coord.request_stop()
        # ...and wait for them to actually terminate.
        coord.join(threads)
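One caveat when adapting this pipeline (a sketch, assuming you pass a finite num_epochs instead of the default None): string_input_producer then keeps its epoch counter in a local variable, which must be initialized before the queue runners start, and the exhausted queue ends the loop by raising tf.errors.OutOfRangeError:

example_batch, label_batch = input_pipeline([train_file_name], batch_size=10, num_epochs=2)
with tf.Session() as sess:
    # The epoch counter created by string_input_producer is a *local* variable.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            sess.run([example_batch, label_batch])
    except tf.errors.OutOfRangeError:
        print('Done -- 2 epochs consumed')
    finally:
        coord.request_stop()
        coord.join(threads)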