tf.train.batch()函数中,主要需要关注的是四个参数
tensors, batch_size, num_threads, capacity
tensors – 训练数据的来源
batch_size – 每次从队列中取出的数据量
num_threads – 设置用来实现多线程读取
capacity – 队列中的数据量
训练的流程就是首先取出capacity数量的数据加入队列,然后再从队列中取出batch_size数量的数据用于训练,同时补充batch_size数量的数据到队列中。
如果使用的是tf.train.shuffle_batch()函数,则是取出capacity数量的数据加入队列,然后打乱顺序,再取出batch_size用于训练,在补充数据到队列中之后,再一次打乱队列中的数据,以便进行下一次的读取数据。
Tensorflow -- tf.train.batch函数
最新推荐文章于 2021-12-01 15:37:02 发布