TensorFlow——使用 tf.train 队列读取数据的代码
- 下面是使用 tf.train 队列(TensorFlow 1.x 的输入管线 API)读取数据的基础代码,可以按需求补充和修改。
import tensorflow as tf
class TfTrainQueue(object):
    """Read paired image/label data through TF 1.x ``tf.train`` input queues."""

    def __init__(self, im_size):
        """Remember the target image size.

        Args:
            im_size: Side length used for both height and width when images
                are resized in ``get_batch``.
        """
        self._im_size = im_size

    def train(self, data, batch_size, num_epochs):
        """Build the input pipeline and pull batches until the queues run dry.

        Args:
            data: 4-tuple ``(img1, img2, label1, label2)``; see ``get_batch``.
            batch_size: Number of examples per batch.
            num_epochs: How many times each input producer replays its data.
        """
        img1_batch, img2_batch, label1_batch, label2_batch = self.get_batch(
            data, batch_size, num_epochs)

        # NOTE: ``tf.Session`` must be instantiated; using the class object
        # itself as a context manager fails at runtime.
        with tf.Session() as sess:
            # Run the initializers first. The local initializer is included
            # because the input producers are given ``num_epochs``.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            # Coordinator used to signal and join the queue-runner threads.
            coord = tf.train.Coordinator()
            # Start the threads that fill the input queues.
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    # Fetch one batch; training code that consumes these
                    # values is expected to be added here.
                    img1, img2, label1, label2 = sess.run(
                        [img1_batch, img2_batch, label1_batch, label2_batch])
            except tf.errors.OutOfRangeError:
                # Raised once the input queues are exhausted.
                print("done! now lets kill all the threads……")
            finally:
                # Ask every thread to stop, then wait for them to finish.
                coord.request_stop()
                coord.join(threads)

    def get_batch(self, data, batch_size, num_epochs):
        """Create batched tensors for two image streams and two label streams.

        Args:
            data: 4-tuple ``(img1, img2, label1, label2)``. ``img1`` and
                ``img2`` are passed to ``tf.train.string_input_producer`` as
                file lists (presumably JPEG paths -- confirm with caller);
                ``label1`` and ``label2`` are sliced along their first
                dimension by ``tf.train.slice_input_producer``.
            batch_size: Number of examples per batch; also used as the
                batching queue capacity.
            num_epochs: Number of passes each producer makes over its input.

        Returns:
            Tuple ``(img1_batch, img2_batch, label1_batch, label2_batch)``.
        """

        def read_image(file_list):
            """Return one decoded, resized image tensor scaled to [-1, 1]."""
            filename_queue = tf.train.string_input_producer(
                file_list, shuffle=False, num_epochs=num_epochs)
            reader = tf.WholeFileReader()
            _, value = reader.read(filename_queue)
            # Decode the raw file contents. Force three channels so the
            # static shape set below also holds for grayscale JPEGs.
            image = tf.image.decode_jpeg(value, channels=3)
            image_resize = tf.image.resize_images(
                image, [self._im_size, self._im_size])
            # tf.train.batch needs a fully defined static shape.
            image_resize.set_shape([self._im_size, self._im_size, 3])
            # Map pixel values from [0, 255] to [-1, 1].
            image_resize = tf.cast(image_resize, dtype=tf.float32) * (1. / 255) * 2 - 1
            return image_resize

        img1, img2, label1, label2 = data
        img1_que = read_image(img1)
        img2_que = read_image(img2)
        # shuffle=False on every producer keeps the streams in input order.
        label1_que, label2_que = tf.train.slice_input_producer(
            [label1, label2],
            shuffle=False,
            num_epochs=num_epochs)
        # A single batching thread dequeues from all four streams together.
        img1_batch, img2_batch, label1_batch, label2_batch = tf.train.batch(
            [img1_que, img2_que, label1_que, label2_que],
            batch_size=batch_size, num_threads=1,
            capacity=batch_size)
        return img1_batch, img2_batch, label1_batch, label2_batch