TensorFlow数据读取模块调用过程(cifar10)

最近在看TensorFlow数据读取模块,有了一点思路,先把读取部分的调用过程写下来,以cifar10为例。


入口 cifar10_train.py

distorted_inputs() 函数执行数据读取

def train():
  with tf.Graph().as_default():
	......
    # Get images and labels for CIFAR-10.
    # 从二进制文件中读取数据 images, labels
    images, labels = cifar10.distorted_inputs()  # 1 -------------->
	......

1 --------------> cifar10.py

cifar10.distorted_inputs() 是将 data_dir 定义后再调用 cifar10_input.distorted_inputs()
cifar10_input.distorted_inputs 是执行数据读取的主要函数

def distorted_inputs():
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=FLAGS.batch_size)  # 2 -------------->
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels	

2 --------------> cifar10_input.py 

cifar10_input.distorted_inputs() 主要有几部分组成

1. 生成文件名队列。使用 tf.train.string_input_producer()函数生成文件名队列,通过调用分支 3 进行具体的调用过程分析。

2. 文件读取与解析。通过在函数 read_cifar10() 中定义了对应文件类型的阅读器及解析器,并通过对应的 read 及 decode 方法得到样本数据,通过调用分支 4 进行具体的调用过程分析。

3. 样本处理(包括裁剪、翻转等)

4. 样本批处理。通过在函数 _generate_image_and_label_batch() 中设置线程数,并调用不同的批处理函数,进行数据的批处理,通过调用分支 5 进行具体的调用过程分析。 

===================================================================================================

def distorted_inputs(data_dir, batch_size):
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  # 生成要读取的文件名队列
  filename_queue = tf.train.string_input_producer(filenames)  # 3 -------------->

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue) # 4 -------------->
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.
  # 为训练网络进行图像处理。注意应用于图像的许多随机失真。
  
  # Randomly crop a [height, width] section of the image.
  # 随机裁剪图像为 [height,width] 像素大小的图片
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Randomly flip the image horizontally.
  # 随意地水平翻转图像。
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Because these operations are not commutative, consider randomizing
  # the order their operation.
  # 因为这些操作是不可交换的,所以请考虑将它们的操作随机化。
  # 随机的改变图片的亮度
  distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
                                              
  # 随机的改变图片的对比度                                             
  distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)

  # Subtract off the mean and divide by the variance of the pixels.
  # 图像的白化:减去平均值并除以像素的方差,均值与方差的均衡,降低图像明暗、光照差异引起的影响
  float_image = tf.image.per_image_standardization(distorted_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  # 确保随机 shuffling 具有良好的混合性能。
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print ('Filling queue with %d CIFAR images before starting to train. '
         'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  # 构造 batch_size 样本集(图像+标签)
  return _generate_image_and_label_batch(float_image, read_input.label,
                             min_queue_examples, batch_size, shuffle=True) # 5 -------------->


3 --------------> tensorflow/python/training/input.py

string_input_producer() 函数是将字符串(比如文件名)入队到一个队列中,并且添加该队的 QueueRunner 到当前图的 QUEUE_RUNNER collection 中。
其中有几个主要的参数:
num_epochs: 限制 string_tensor 中字符串入队的次数,如果没有定义的话,就是无限次将 string_tensor 中的字符串入队到队列中。
shuffle: 表示是否乱序,如果是 True, 表示字符串入队到队列中是以乱序的形式。

除了 string_input_producer() 之外还有两个函数,实现不同对象的入队操作
# 将 0 - (limit-1) 的整数入队到队列中
range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
                         capacity=32, shared_name=None, name=None)
# 将 tensor_list 中各 Tensor 的切片入队到队列中  
slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
                         capacity=32, shared_name=None, name=None)

string_input_producer() 会调用 input_producer() 进行具体的操作。


def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None,
                         capacity=32, shared_name=None, name=None, cancel_op=None):	                        
  with ops.name_scope(name, "input_producer", [string_tensor]) as name:
    string_tensor = ops.convert_to_tensor(string_tensor, dtype=dtypes.string)
    with ops.control_dependencies([control_flow_ops.Assert(
            math_ops.greater(array_ops.size(string_tensor), 0), [not_null_err])]):
      string_tensor = array_ops.identity(string_tensor)
	  
    return input_producer(input_tensor=string_tensor, element_shape=[], num_epochs=num_epochs,
        shuffle=shuffle, seed=seed, capacity=capacity, shared_name=shared_name, name=name,
         summary_name="fraction_of_%d_full" % capacity, cancel_op=cancel_op)  # 3.1 -------------->

3.1 --------------> tensorflow/python/training/input.py

input_producer() 函数主要做了以下操作:
1. 根据参数 shuffle 和 num_epochs 确定 input_tensor,分别通过分支 3.1.1 和 3.1.2 进行具体的分析。 
2. 创建队列及入队操作,分别通过分支 3.1.3 和 3.1.4 进行具体的分析。 
3. 创建 QueueRunner, 并将其加入图的 QUEUE_RUNNER 集合中,通过分支 3.1.5 进行具体的分析。

注意:
queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op)
将队列 q 的操作列表 [enq] 添加到一个 QueueRunner, 
这里的操作列表 [enq] 会影响后续训练过程中创建线程的个数。(QueueRunner.create_threads() 函数)

def input_producer(input_tensor, element_shape=None, num_epochs=None, shuffle=True, seed=None,
                   capacity=32, shared_name=None, summary_name=None, name=None, cancel_op=None):
  with ops.name_scope(name, "input_producer", [input_tensor]):
    input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
    element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
					  
    # 是否乱序乱序
    if shuffle:
      input_tensor = random_ops.random_shuffle(input_tensor, seed=seed) # 3.1.1 -------------->
	  
    # 限制迭代次数
    input_tensor = limit_epochs(input_tensor, num_epochs) # 3.1.2 -------------->
	
    # 创建队列及入队操作
    q = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=[input_tensor.dtype.base_dtype], 
            shapes=[element_shape], shared_name=shared_name, name=name) # 3.1.3 -------------->                              							
    enq = q.enqueue_many([input_tensor]) # 3.1.4 --------------> 
	
    # 创建 QueueRunner,并将其加入图的集合中
    queue_runner.add_queue_runner(queue_runner.QueueRunner
	                             (q, [enq], cancel_op=cancel_op))  # 3.1.5 -------------->
	
    if summary_name is not None:
      summary.scalar(summary_name, math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
    return q       
        	

3.1.1 --------------> tensorflow/python/ops/rand

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值