tf.train.batch 将数据 batch 化.
定义:
函数 tf.train.batch
:
tf.train.batch(tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None )
作用:
创建 tensors
的 batches.
参数:
- tensors - 送入队列的 tensors 列表或字典.
tf.train.batch
函数返回值是相同类型的tensors
. - batch_size - 从队列拉取的样本的 batch size
- num_threads - 入队
tensors
的线程数. 如果num_thread>1
,则 batch 操作是非确定的. - capacity - 整数,队列容量,队列里样本元素的最大数.
- enqueue_many - tensors 内的每个 tensor 是否是单个样本.
- shapes - (可选)每个样本的 shape. 默认是
tensors
的shapes. - dynamic_pad - Boolean 值. 输入 shapes 的变量维度. 出队后会自动填补维度,以保持batch内 shapes 一致.
- allow_smaller_final_batch - (可选) Boolean 值. 如果队列中样本不足 batch,允许最后的 batch 样本数小于 batch_size.
- shard_name - (可选). 如果设置了该参数,则在多个会话给定的名字时,共享队列.
- name - (可选). 操作operations 的名字.
简单说明:
该函数的输入 tensors
是张量tensors列表或字典,且函数返回相同类型的 tensors
.
该函数采用队列queue 来实现. 队列的 QueueRunner
被添加到当前 Grahp
的 QUEUE_RUNNER
集合(collection) 中.
如果 enqueue_many=False
,则 tensor
表示单个样本.
对于 shape 为 [x, y, z]
的输入 tensor,该函数输出为,shape 为 [batch_size, x, y, z]
的 tensor.
如果 enqueue_many=True
,则 tensors
表示 batch 个样本,其中,第一维表示样本的索引,所有的 tensors
都在第一维具有相同的尺寸.
对于 shape 为 [*, x, y, z]
的输入 tensor,该函数输出为,shape 为 [batch_size, x, y, z]
的 tensor.
capacity
参数控制着预取队列的长度(how long the prefetching is allowed to grow the queues),队列容量.
如果输入队列用完,则返回 tf.errors.OutofRangeError
.
如果 dynamic_pad=False
,则必须保证 shapes
参数被传递,或 tensors
内的所有张量必须已经预定义 shapes. 否则,会出现 ValueError
.
如果 dynamic_pad=True
,则张量的秩已知即可,但独立维度的 shape 为 None
. 此时,每次入队时,维度为 None
的值的长度是可变的. 出队后,输出的 tensors 会根据当前 minibatch 内的 tensors 的最大 shape 来在右边自动补零. 对于数字 tensors,填补的值为 0;对于字符串 tensors,填补的是空字符. 可见 PaddingFIF0Queue
.
如果 allow_smaller_final_batch=True
,当队列关闭时,如果没有足够的样本元素来填补 batch,则会返回比 batch_size
更小的 batch 值,否则会丢弃样本元素.
示例:
tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
- [example, label] - 样本和标签,可以是单个样本和单个标签
- batch_size - 返回的一个 batch 的样本数.
- capacity - 队列容量. 按顺序组合成一个batch.
tf.train.batch 使用示例
转自: 关于Tensorflow中的tf.train.batch函数
tensorflow中的读取数据的队列,简单的说,就是计算图是从一个管道中读取数据的,录入管道是用的现成的方法,读取也是.
为了保证多线程的时候从一个管道读取数据不会乱,所以这种时候读取的时候需要线程管理的相关操作。
测试了一个简单的操作,就是给一个有序的数据,看看读出来是不是有序的,结果发现是有序的,代码如下:
import tensorflow as tf
import numpy as np
def generate_data():
num = 25
label = np.asarray(range(0, num))
images = np.random.random([num, 5, 5, 3])
print('label size :{}, image size {}'.format(label.shape, images.shape))
return label, images
def get_batch_data():
label, images = generate_data()
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False) # 默认 shuffle=True
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)
return image_batch, label_batch
image_batch, label_batch = get_batch_data()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
i = 0
try:
while not coord.should_stop():
image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
i += 1
for j in range(10):
print(image_batch_v.shape, label_batch_v[j])
except tf.errors.OutOfRangeError:
print("done")
finally:
coord.request_stop()
coord.join(threads)
最后修改:2018 年 10 月 09 日 09 : 31 AM