tf.train.slice_input_producer()
tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取一个tensor放入队列。
函数:
slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)
参数:
tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个图像,就应该有多少个对应的标签。
num_epochs: 可选参数,是一个整数值,代表迭代的次数,如果设置 num_epochs=None,生成器可以无限次遍历tensor列表,如果设置为 num_epochs=N,生成器只能遍历tensor列表N次。
shuffle:bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle=True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用 tf.train.batch函数就可以了;如果shuffle=False,就需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本。
seed: 可选的整数,是生成随机数的种子,在第三个参数设置为shuffle=True的情况下才有用。
capacity:设置tensor列表的容量。
shared_name:可选参数,如果设置一个‘shared_name’,则在不同的上下文环境(Session)中可以通过这个名字共享生成的tensor。
name:可选,设置操作的名称。
tf.train.start_queue_runners()
TensorFlow的Session对象是支持多线程的,可以在同一个会话(Session)中创建多个线程,并行执行。在Session中的所有线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候,队列必须能被正确地关闭。
TensorFlow提供了两个类来实现对Session中多线程的管理:tf.Coordinator和 tf.QueueRunner,这两个类往往一起使用。
Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用tf.train.Coordinator()来创建一个线程管理器(协调器)对象。
QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程同时将多个tensor(训练数据)推送入文件名称队列中,具体执行函数是 tf.train.start_queue_runners , 只有调用tf.train.start_queue_runners 之后,才会真正把tensor推入内存序列中,供计算单元调用,否则会由于内存序列为空,数据流图会处于一直等待状态。
代码:
import tensorflow as tf
# 定义四个图片路径列表
images = ['image1', 'image2', 'image3', 'image4']
# 定义四个label标签的列表
labels = [1, 2, 3, 4]
#产生图像和标签对应的tensor
[images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)
with tf.Session() as sess:
# 对全局的变量进行初始化
sess.run(tf.local_variables_initializer())
# 放入线程后,才会产生张量
tf.train.start_queue_runners(sess=sess)# 启动队列填充的线程
for i in range(8):
# 从文件队列中获取数据,总共4条,num_epochs=2,所以最多取8条
print(sess.run([images, labels]))
'''
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image3.jpg', 3]
[b'image1.jpg', 1]
[b'image4.jpg', 4]
[b'image2.jpg', 2]
[b'image3.jpg', 3]
[b'image1.jpg', 1]
'''
tf.train.string_input_producer()
函数和tf.train.slice_input_producer()类似,不过是针对文件的生成器,传入文件路径列表,每次吐出一个文件
代码:
'''
读入文件信息
'''
import tensorflow as tf
# 文件相应位置
filename = ['A.csv','B.csv', 'C.csv']
# 将文件的路径作为参数传入函数
# 输出是文件队列,无法直接获取文件的值
file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2)
# 文件读取器
reader = tf.WholeFileReader()
# key:文件名 value:文件值
key, value = reader.read(file_queue)
with tf.Session() as sess:
# 对变量进行赋值
sess.run(tf.local_variables_initializer())
# 定义文件队列填充的线程
tf.train.start_queue_runners(sess=sess)
for i in range(6):
# 文件数量 3 * 2epochs
print(sess.run([key, value]))
'''
[b'A.csv', b'1']
[b'B.csv', b'2']
[b'C.csv', b'3']
[b'A.csv', b'1']
[b'B.csv', b'2']
[b'C.csv', b'3']
'''
参考:https://blog.csdn.net/keyandi/article/details/103683761
转载请注明转自:https://blog.csdn.net/Owen_goodman/article/details/107709645