tensorflow中协调器 tf.train.Coordinator 和入队线程启动器 tf.train.start_queue_runners
ensorFlow的Session对象是支持多线程的,可以在同一个会话(Session)中创建多个线程,并行执行。在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。
TensorFlow提供了两个类来实现对Session中多线程的管理:tf.Coordinator和 tf.QueueRunner,这两个类往往一起使用。
Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator()来创建一个线程管理器(协调器)对象。
QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程同时将多个tensor(训练数据)推送入文件名称队列中,具体执行函数是 tf.train.start_queue_runners , 只有调用 tf.train.start_queue_runners 之后,才会真正把tensor推入内存序列中,供计算单元调用,否则会由于内存序列为空,数据流图会处于一直等待状态。
tf中的数据读取机制如下图:
调用 tf.train.slice_input_producer,从 本地文件里抽取tensor,准备放入Filename Queue(文件名队列)中;
调用 tf.train.batch,从文件名队列中提取tensor,使用单个或多个线程,准备放入文件队列;
调用 tf.train.Coordinator() 来创建一个线程协调器,用来管理之后在Session中启动的所有线程;
调用tf.train.start_queue_runners, 启动入队线程,由多个或单个线程,按照设定规则,把文件读入Filename Queue中。函数返回线程ID的列表,一般情况下,系统有多少个核,就会启动多少个入队线程(入队具体使用多少个线程在tf.train.batch中定义);
文件从 Filename Queue中读入内存队列的操作不用手动执行,由tf自动完成;
调用sess.run 来启动数据出列和执行计算;
使用 coord.should_stop()来查询是否应该终止所有线程,当文件队列(queue)中的所有文件都已经读取出列的时候,会抛出一个 OutofRangeError 的异常,这时候就应该停止Sesson中的所有线程了;
使用coord.request_stop()来发出终止所有线程的命令,使用coord.join(threads)把线程加入主线程,等待threads结束。
# -*- coding:utf-8 -*-
__author__ = 'kang'
# @Author :2020/12/28
# @time :9:58
# @File :代码4-5 将图片文件制作成TFRecord数据集.py
# @Software:PyCharm
import os
import tensorflow as tf
from PIL import Image
from sklearn.utils import shuffle
import numpy as np
from tqdm import tqdm
tf.compat.v1.disable_eager_execution()
def load_sample(sample_dir,shuffleflag=True):
print('loading sample dataset')
lfilenames = []
labelsnames = []
for (dirpath,dirnames,filenames) in os.walk(sample_dir):
for filename in filenames:
filename_apth = os.sep.join([dirpath,filename])
lfilenames.append(filename_apth)
labelsnames.append(dirpath.split('\\')[-1])
lab = list(sorted(set(labelsnames)))
labdict = dict(zip(lab,list(range(len(lab)))))
labels = [labdict[i] for i in labelsnames]
if shuffleflag == True:
return shuffle(np.asarray( lfilenames),np.asarray( labels)),np.asarray(lab)
else:
return (np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)
def makeTFRec(filenames,labels):
writer = tf.io.TFRecordWriter('mydata.tfrecords')
for i in tqdm(range(0,len(labels))):
img = Image.open(filenames[i])
img = img.resize((256,256))
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
#存放图片的标签label
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),
#存放具体的图片
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) #example对象对label和image数据进行封装
writer.write(example.SerializeToString()) #序列化为字符串
writer.close() #数据集制作完成
################将tf数据集转化为图片##########################
def read_and_decode(filenames, flag='train',batch_size = 3):
# 根据文件名生成一个队列
if flag == 'train':
filename_queue = tf.compat.v1.train.string_input_producer(filenames) # 默认已经是shuffle并且循环读取
else:
filename_queue = tf.compat.v1.train.string_input_producer(filenames, num_epochs=1, shuffle=False)
reader = tf.compat.v1.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.io.parse_single_example(serialized=serialized_example, #取出包含image和label的feature对象
features={
'label': tf.io.FixedLenFeature([], tf.int64),
'img_raw' : tf.io.FixedLenFeature([], tf.string),
})
# tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.io.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [256, 256, 3])
#
label = tf.cast(features['label'], tf.int32)
if flag == 'train':
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 # 归一化
img_batch, label_batch = tf.compat.v1.train.batch([image, label], # 还可以使用tf.train.shuffle_batch进行乱序批次
batch_size=batch_size, capacity=20)
# img_batch, label_batch = tf.train.shuffle_batch([image, label],
# batch_size=batch_size, capacity=20,
# min_after_dequeue=10)
return img_batch, label_batch
return image, label
if __name__ == '__main__':
directory = r'D:\学习的书籍\深度学习之Tensorflow工程项目优化\04data\第4章 配套资源\第4章 配套资源\man_woman' # 定义样本路径
(filenames, labels), _ = load_sample(directory, shuffleflag=False) # 载入文件名称与标签
makeTFRec(filenames, labels)
TFRecordfilenames = ["mydata.tfrecords"]
image, label = read_and_decode(TFRecordfilenames, flag='test') # 以测试的方式打开数据集
saveimgpath = 'show\\' # 定义保存图片路径
if tf.io.gfile.exists(saveimgpath):
tf.io.gfile.rmtree(saveimgpath)
tf.io.gfile.makedirs(saveimgpath)
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.local_variables_initializer())
#启用多线程
coord = tf.compat.v1.train.Coordinator()
threads = tf.compat.v1.train.start_queue_runners(coord=coord)
myset = set([])
try:
i=0
while True:
example,examplelab = sess.run([image,label])#在会话中取出image,label
examplelab = str(examplelab)
if examplelab not in myset:
myset.add(examplelab)
tf.io.gfile.makedirs(saveimgpath + examplelab)
print(saveimgpath + examplelab, i)
img = Image.fromarray(example, 'RGB') # 转换Image格式
img.save(saveimgpath + examplelab + '/' + str(i) + '_Label_' + '.jpg') # 存下图片
print(i)
i = i + 1
except tf.errors.OutOfRangeError:
print('')
finally:
coord.request_stop()
coord.join()