TensorFlow读出TFRecord中的数据,然后再经过预处理操作,此时需要注意:数据还是单个,而网络的输入一般以Batch为单位,因此我们需要将单个的数据组合成一个Batch,做为神经网络的输入。
TensorFlow提供组合训练数据的函数有四个:tf.train.batch()
,tf.train.shuffle_batch()
与tf.train.batch_join
、tf.train.shuffle_batch_join
。
最近花了很久理解tf.train.batch()
和tf.train.batch_join
的输入维度与输出维度之间的关系,真的是很头大,当我认真研读了TensorFlow的官方文档之后发现了一些玄机。官方文档的每个单词的每个字母都不能忽略,尤其是函数的参数的单复数往往藏着玄机。
首先看tf.train.batch
这个函数:
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
注意这个函数的第一个参数是tensors,是一个复数,也就是说它将你的输入看做很多个tensor组成的tensors,比如我输入一个shape为[5,4,3,2]的list,它就将这个tensors看做5个[4,3,2]的tensor。理解了这一点之后对输出维度的理解就不难了。
对enqueue_many
的理解:
If enqueue_many is False, tensors is assumed to represent a single example. An input tensor with shape [x, y, z] will be output as a tensor with shape [batch_size, x, y, z].
If enqueue_many is True, tensors is assumed to represent a batch of examples, where the first dimension is indexed by example, and all members of tensors should have the same size in the first dimension. If an input tensor has shape [*, x, y, z], the output will have shape [batch_size, x, y, z]. The capacity argument controls the how long the prefetching is allowed to grow the queues.
注意官方文档里的an input tensor是单数的,如果你把整个tensors看做一个tensor,那么输出的维度和你预想的就很难对上,个人理解如下:
如果enqueue_many
设置为False,tensors中的每个tensor被认为代表单个样本。那么输入维度(shape)为[x,y,z]的tensor,将会输出一个维度为[batch_size,x,y,z]的张量。
如果enqueue_many
设置为True,参数tensors中的每个tensor被认为是一批次的样本,其中第一维是按样本编索引的,如果输入的tensor的维度是[*,x,y,z],那么输出的张量的维度将会是[batch_size,x,y,z]。
比如输入shape为[5,4,3,2]的tensors,每个tensor的shape为[4,3,2],设batch_size
为4,当enqueue_many
设置为False时,每个tensor的输出的shape为[4,4,3,2],那么总的输出为[5,4,4,3,2]。当enqueue_many
设置为True时,每个tensor被认为是一个batch的样本,那么它的输出为[4,3,2],总的输出为[5,4,3,2]
代码实现:
import tensorflow as tf
# shape为[5,4,3,2]的tensors
tensors = [[[[1,2],[3,4],[5,6]],[[7,8],[9,10],[11,12]],[[13,14],[15,16],[17,18]],[[19,20],[21,22],[23,24]]], [[[25,26],[27,28],[29,30]],[[31,32],[33,34],[35,36]],[[37,38],[39,40],[41,42]],[[43,44],[45,46],[47,48]]], [[[