tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
函数功能:利用一个tensor的列表或字典来获取一个batch数据
参数介绍:
tensors:一个列表或字典的tensor用来进行入队
batch_size:设置每次从队列中获取出队数据的数量
num_threads:用来控制入队tensors线程的数量,如果num_threads大于1,则batch操作将是非确定性的,输出的batch可能会乱序
capacity:一个整数,用来设置队列中元素的最大数量
enqueue_many:在tensors中的tensor是否是单个样本
shapes:可选,每个样本的shape,默认是tensors的shape
dynamic_pad:Boolean值.允许输入变量的shape,出队后会自动填补维度,来保持与batch内的shapes相同
allow_samller_final_batch:可选,Boolean值,如果为True队列中的样本数量小于batch_size时,出队的数量会以最终遗留下来的样本进行出队,如果为Flalse,小于batch_size的样本不会做出队处理
shared_name:可选,通过设置该参数,可以对多个会话共享队列
name:可选,操作的名字
import tensorflow as tf
import numpy as np
import cv2
def read_single_tfrecord(tfrecord_file, batch_size, net):
# generate a input queue
# each epoch shuffle
# 创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer([tfrecord_file], shuffle=True)
# read tfrecord
# 读取tfrecord文件
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
image_features = tf.parse_single_example(
serialized_example,
features={
'image/encoded': tf.FixedLenFeature([], tf.string), # one image one record
'image/label': tf.FixedLenFeature([], tf.int64),
'image/roi': tf.FixedLenFeature([4], tf.float32),
'image/landmark': tf.FixedLenFeature([10], tf.float32)
}
)
if net == 'PNet':
image_size = 12
elif net == 'RNet':
image_size = 24
else:
image_size = 48
# tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(image_features['image/encoded'], tf.uint8)
# reshape操作
image = tf.reshape(image, [image_size, image_size, 3])
# 图像进行归一化,范围为[-1, 1]
image = (tf.cast(image, tf.float32) - 127.5) / 128
# image = tf.image.per_image_standardization(image)
# 读取label、roi(region of interest)、landmark
label = tf.cast(image_features['image/label'], tf.float32)
roi = tf.cast(image_features['image/roi'], tf.float32)
landmark = tf.cast(image_features['image/landmark'], tf.float32)
# tf.train.batch() 按顺序读取队列中的数据
# batch_size:从队列中提取新的批量大小
# num_threads:线程数量.若批次是不确定 num_threads > 1
# capacity:队列中元素的最大数量
image, label, roi, landmark = tf.train.batch(
[image, label, roi, landmark],
batch_size=batch_size,
num_threads=2,
capacity=1 * batch_size
)
label = tf.reshape(label, [batch_size])
roi = tf.reshape(roi, [batch_size, 4])
landmark = tf.reshape(landmark, [batch_size, 10])
return image, label, roi, landmark
# 返回image, label, roi,landmark
实例:
TensorFlow提供了两个类来实现对Session中多线程的管理:tf.Coordinator和 tf.QueueRunner,这两个类往往一起使用。Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象。QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程同时将多个tensor(训练数据)推送入文件名称队列中,具体执行函数是 tf.train.start_queue_runners , 只有调用 tf.train.start_queue_runners 之后,才会真正把tensor推入内存序列中,供计算单元调用,否则会由于内存序列为空,数据流图会处于一直等待状态。
import numpy as np
import tensorflow as tf
def next_batch():
datasets = np.asarray(range(0, 30))
print(datasets)
input_queue = tf.train.slice_input_producer([datasets], shuffle=False, num_epochs=1)
data_batchs = tf.train.batch(input_queue, batch_size=10, num_threads=1,
capacity=30, allow_smaller_final_batch=False)
return data_batchs
if __name__ == "__main__":
data_batchs = next_batch()
sess = tf.Session()
sess.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
try:
while not coord.should_stop():
data = sess.run([data_batchs])
print(data)
except tf.errors.OutOfRangeError:
print("complete")
finally:
coord.request_stop()
coord.join(threads)