陈伟@航天科技智慧城市研究院 chenwei@ascs.tech
tf.train.batch与tf.train.shuffle_batch的作用都是从队列中读取数据,它们的区别是是否随机打乱数据来读取。
tf.train.batch()函数
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
- tensors:一个列表或字典的tensor用来进行入队
- batch_size:每次从队列中获取出队数据的数量
- num_threads:用来控制入队tensors线程的数量,如果num_threads大于1,则batch操作将是非确定性的,输出的batch可能会乱序
- capacity:一个整数,用来设置队列中元素的最大数量
- enqueue_many:在tensors中的张量是否是单个样本,若为False,则认为tensors代表一个样本.输入张量形状为[x, y, z]时,则输出张量形状为[batch_size, x, y, z],若为True,则认为tensors代表一批样本,其中第一个维度为样本的索引,并且所有成员tensors在第一维中应具有相同大小.若输入张量形状为[*, x, y, z],则输出张量的形状为[batch_size, x, y, z]
- shapes:每个样本的shape,默认是tensors的shape
- dynamic_pad:为True时允许输入变量的shape,出队后会自动填补维度,来保持与batch内的shapes相同
- allow_smaller_final_batch:为True队列中的样本数量小于batch_size时,出队的数量会以最终遗留下来的样本进行出队,如果为Flalse,小于batch_size的样本不会做出队处理
- shared_name:如果设置,则队列将在多个会话中以给定名称共享
- name:操作的名字
代码演示
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
images = np.random.random([5, 2]) # 5x2的矩阵
print(images)
label = np.asarray