最近在看TensorFlow数据读取模块,有了一点思路,先把读取部分的调用过程写下来,以cifar10为例。
入口 cifar10_train.py
distorted_inputs() 函数执行数据读取
def train():
with tf.Graph().as_default():
......
# Get images and labels for CIFAR-10.
# 从二进制文件中读取数据 images, labels
images, labels = cifar10.distorted_inputs() # 1 -------------->
......
1 --------------> cifar10.py
cifar10.distorted_inputs() 是将 data_dir 定义后再调用 cifar10_input.distorted_inputs()
cifar10_input.distorted_inputs 是执行数据读取的主要函数
def distorted_inputs():
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
images, labels = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=FLAGS.batch_size) # 2 -------------->
if FLAGS.use_fp16:
images = tf.cast(images, tf.float16)
labels = tf.cast(labels, tf.float16)
return images, labels
2 --------------> cifar10_input.py
cifar10_input.distorted_inputs() 主要有几部分组成
1. 生成文件名队列。使用 tf.train.string_input_producer()函数生成文件名队列,通过调用分支 3 进行具体的调用过程分析。
2. 文件读取与解析。通过在函数 read_cifar10() 中定义了对应文件类型的阅读器及解析器,并通过对应的 read 及 decode 方法得到样本数据,通过调用分支 4 进行具体的调用过程分析。
3. 样本处理(包括裁剪、翻转等)
4. 样本批处理。通过在函数 _generate_image_and_label_batch() 中设置线程数,并调用不同的批处理函数,进行数据的批处理,通过调用分支 5 进行具体的调用过程分析。
===================================================================================================
def distorted_inputs(data_dir, batch_size):
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# Create a queue that produces the filenames to read.
# 生成要读取的文件名队列
filename_queue = tf.train.string_input_producer(filenames) # 3 -------------->
# Read examples from files in the filename queue.
read_input = read_cifar10(filename_queue) # 4 -------------->
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
height = IMAGE_SIZE
width = IMAGE_SIZE
# Image processing for training the network. Note the many random
# distortions applied to the image.
# 为训练网络进行图像处理。注意应用于图像的许多随机失真。
# Randomly crop a [height, width] section of the image.
# 随机裁剪图像为 [height,width] 像素大小的图片
distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
# Randomly flip the image horizontally.
# 随意地水平翻转图像。
distorted_image = tf.image.random_flip_left_right(distorted_image)
# Because these operations are not commutative, consider randomizing
# the order their operation.
# 因为这些操作是不可交换的,所以请考虑将它们的操作随机化。
# 随机的改变图片的亮度
distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
# 随机的改变图片的对比度
distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
# Subtract off the mean and divide by the variance of the pixels.
# 图像的白化:减去平均值并除以像素的方差,均值与方差的均衡,降低图像明暗、光照差异引起的影响
float_image = tf.image.per_image_standardization(distorted_image)
# Set the shapes of tensors.
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
# Ensure that the random shuffling has good mixing properties.
# 确保随机 shuffling 具有良好的混合性能。
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
min_fraction_of_examples_in_queue)
print ('Filling queue with %d CIFAR images before starting to train. '
'This will take a few minutes.' % min_queue_examples)
# Generate a batch of images and labels by building up a queue of examples.
# 构造 batch_size 样本集(图像+标签)
return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size, shuffle=True) # 5 -------------->
3 --------------> tensorflow/python/training/input.py
string_input_producer() 函数是将字符串(比如文件名)入队到一个队列中,并且添加该队的 QueueRunner 到当前图的 QUEUE_RUNNER collection 中。
其中有几个主要的参数:
num_epochs: 限制 string_tensor 中字符串入队的次数,如果没有定义的话,就是无限次将 string_tensor 中的字符串入队到队列中。
shuffle: 表示是否乱序,如果是 True, 表示字符串入队到队列中是以乱序的形式。
除了 string_input_producer() 之外还有两个函数,实现不同对象的入队操作
# 将 0 - (limit-1) 的整数入队到队列中
range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
# 将 tensor_list 中各 Tensor 的切片入队到队列中
slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
string_input_producer() 会调用 input_producer() 进行具体的操作。
def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None, cancel_op=None):
with ops.name_scope(name, "input_producer", [string_tensor]) as name:
string_tensor = ops.convert_to_tensor(string_tensor, dtype=dtypes.string)
with ops.control_dependencies([control_flow_ops.Assert(
math_ops.greater(array_ops.size(string_tensor), 0), [not_null_err])]):
string_tensor = array_ops.identity(string_tensor)
return input_producer(input_tensor=string_tensor, element_shape=[], num_epochs=num_epochs,
shuffle=shuffle, seed=seed, capacity=capacity, shared_name=shared_name, name=name,
summary_name="fraction_of_%d_full" % capacity, cancel_op=cancel_op) # 3.1 -------------->
3.1 --------------> tensorflow/python/training/input.py
input_producer() 函数主要做了以下操作:
1. 根据参数 shuffle 和 num_epochs 确定 input_tensor,分别通过分支 3.1.1 和 3.1.2 进行具体的分析。
2. 创建队列及入队操作,分别通过分支 3.1.3 和 3.1.4 进行具体的分析。
3. 创建 QueueRunner, 并将其加入图的 QUEUE_RUNNER 集合中,通过分支 3.1.5 进行具体的分析。
注意:
queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op)
将队列 q 的操作列表 [enq] 添加到一个 QueueRunner,
这里的操作列表 [enq] 会影响后续训练过程中创建线程的个数。(QueueRunner.create_threads() 函数)
def input_producer(input_tensor, element_shape=None, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, summary_name=None, name=None, cancel_op=None):
with ops.name_scope(name, "input_producer", [input_tensor]):
input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
# 是否乱序乱序
if shuffle:
input_tensor = random_ops.random_shuffle(input_tensor, seed=seed) # 3.1.1 -------------->
# 限制迭代次数
input_tensor = limit_epochs(input_tensor, num_epochs) # 3.1.2 -------------->
# 创建队列及入队操作
q = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=[input_tensor.dtype.base_dtype],
shapes=[element_shape], shared_name=shared_name, name=name) # 3.1.3 -------------->
enq = q.enqueue_many([input_tensor]) # 3.1.4 -------------->
# 创建 QueueRunner,并将其加入图的集合中
queue_runner.add_queue_runner(queue_runner.QueueRunner
(q, [enq], cancel_op=cancel_op)) # 3.1.5 -------------->
if summary_name is not None:
summary.scalar(summary_name, math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
return q
3.1.1 --------------> tensorflow/python/ops/rand