官方文档链接:https://tensorflow.google.cn/versions/r1.8/api_docs/python/tf/train/batch
-
tf.train.batch(
-
tensors,
-
batch_size,
-
num_threads=
1,
-
capacity=
32,
-
enqueue_many=
False,
-
shapes=
None,
-
dynamic_pad=
False,
-
allow_smaller_final_batch=
False,
-
shared_name=
None,
-
name=
None
-
)
函数功能:利用一个tensor的列表或字典来获取一个batch数据
参数介绍:
- tensors:一个列表或字典的tensor用来进行入队
- batch_size:设置每次从队列中获取出队数据的数量
- num_threads:用来控制入队tensors线程的数量,如果num_threads大于1,则batch操作将是非确定性的,输出的batch可能会乱序
- capacity:一个整数,用来设置队列中元素的最大数量
- enqueue_many:在tensors中的tensor是否是单个样本
- shapes:可选,每个样本的shape,默认是tensors的shape
- dynamic_pad:Boolean值.允许输入变量的shape,出队后会自动填补维度,来保持与batch内的shapes相同
- allow_samller_final_batch:可选,Boolean值,如果为True队列中的样本数量小于batch_size时,出队的数量会以最终遗留下来的样本进行出队,如果为Flalse,小于batch_size的样本不会做出队处理
- shared_name:可选,通过设置该参数,可以对多个会话共享队列
- name:可选,操作的名字
从数组中每次获取一个batch_size的数据
-
import numpy
as np
-
import tensorflow
as tf
-
-
def next_batch():
-
datasets = np.asarray(range(
0,
20))
-
input_queue = tf.train.slice_input_producer([datasets],shuffle=
False,num_epochs=
1)
-
data_batchs = tf.train.batch(input_queue,batch_size=
5,num_threads=
1,
-
capacity=
20,allow_smaller_final_batch=
False)
-
return data_batchs
-
-
if __name__ ==
"__main__":
-
data_batchs = next_batch()
-
sess = tf.Session()
-
sess.run(tf.initialize_local_variables())
-
coord = tf.train.Coordinator()
-
threads = tf.train.start_queue_runners(sess,coord)
-
try:
-
while
not coord.should_stop():
-
data = sess.run([data_batchs])
-
print(data)
-
except tf.errors.OutOfRangeError:
-
print(
"complete")
-
finally:
-
coord.request_stop()
-
coord.join(threads)
-
sess.close()
注意:tf.train.batch这个函数的实现是使用queue,queue的QueueRunner被添加到当前计算图的"QUEUE_RUNNER"集合中,所在使用初始化器的时候,需要使用tf.initialize_local_variables(),如果使用tf.global_varialbes_initialize()时,会报: Attempting to use uninitialized value
更多参考:
https://blog.csdn.net/weixin_44606212/article/details/88644327