tf.train.slice_input_producer()、tf.train.batch()、tf.train.shuffle_batch()函数

Tensorflow利用slice_input_producer创建文件名队列,tf.train.batch按顺序出队数据,shuffle_batch实现乱序出队。文件名队列在每个epoch结束后抛出OutOfRange异常,slice_input_producer参数包括tensor_list、num_epochs、shuffle等。tf.train.batch处理batch_size、num_threads、capacity等,shuffle_batch关注min_after_dequeue确保数据混合。
摘要由CSDN通过智能技术生成

Tensorflow的数据读取机制:

tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算。
具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。
tf在内存队列之前,还设立了一个文件名队列,文件名队列存放的是参与训练的文件名,要训练 N个epoch,则文件名队列中就含有N个批次的所有文件名。
在N个epoch的文件名最后是一个结束标志,当tf读到这个结束标志的时候,会抛出一个 OutofRange 的异常,外部捕获到这个异常之后就可以结束程序了。而创建tf的文件名队列就需要使用到 tf.train.slice_input_producer 函数。
在这里插入图片描述

slice_input_producer() 创建文件名队列

slice_input_producer(tensor_list, num_epochs=None, shuffle=False, seed=None,
                         capacity=32, shared_name=None, name=None)

tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个数据,就应该有多少个对应

假设我们有一个包含100个样本的数据集,每个样本有两个特征,一个是图像数据,一个是标签。我们希望使用TensorFlow的队列机制异步读取这些数据,并进行训练。 首先,我们可以使用tf.train.slice_input_producer函数将数据集切分成若干个batch,然后每个batch通过多个线程异步读取数据: ```python import tensorflow as tf # 构造数据集 data = [] for i in range(100): image = ... # 加载图像数据 label = ... # 加载标签数据 data.append((image, label)) # 定义batch大小和线程数 batch_size = 32 num_threads = 4 # 使用slice_input_producer函数将数据集切分成若干个batch image_batch, label_batch = tf.train.slice_input_producer(data, batch_size=batch_size, num_threads=num_threads) # 定义数据预处理函数 def preprocess(image, label): # 对图像数据进行预处理 image = ... # 对标签数据进行预处理 label = ... return image, label # 使用map函数将数据预处理函数应用到每个batch中的每个样本 image_batch, label_batch = tf.map_fn(preprocess, (image_batch, label_batch)) # 定义模型 ... # 定义损失函数 ... # 定义优化器 ... # 定义训练操作 train_op = ... # 启动会话 with tf.Session() as sess: # 初始化变量 sess.run(tf.global_variables_initializer()) # 启动多线程读取数据 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) # 训练模型 for i in range(num_steps): _, loss_val = sess.run([train_op, loss]) # 关闭多线程 coord.request_stop() coord.join(threads) ``` 在上面的代码中,我们首先定义了一个包含100个样本的数据集。然后,使用tf.train.slice_input_producer函数将数据集切分成若干个batch,并通过多个线程异步读取数据。接着,我们定义了一个数据预处理函数,并使用tf.map_fn函数将其应用到每个batch中的每个样本。最后,我们定义了模型、损失函数和优化器,并使用tf.Session启动会话进行训练。在训练过程中,我们启动多线程读取数据,并在训练完成后关闭多线程。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值