TensorFlow cifar_10: Reading and Understanding the Code

Preface


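TensorFlow provides example code for the CIFAR-10 benchmark, and the Chinese translation of the official documentation devotes a whole chapter to this convolutional neural network (CNN). As a beginner in deep learning and the TensorFlow framework, I was not very familiar with the many library functions TF provides or with common deep-learning tricks, so I spent two days reading through the code until I understood it. Below I dissect the whole program step by step, as a record of my learning.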

Preparation

The cifar10 program downloaded from GitHub at https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10 consists of:
1. cifar10.py
2. cifar10_eval.py
3. cifar10_input.py
4. cifar10_multi_gpu_train.py
5. cifar10_train.py
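This article only covers the CPU version, so cifar10_multi_gpu_train.py can be ignored. In addition, to speed up debugging and avoid having the program download the data while it runs, you can download the image archive in advance from http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and place it in the default location /tmp/cifar10_data.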
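As a rough sketch of this preparation step, the plain-Python helper below checks that the hand-downloaded archive is in /tmp/cifar10_data, extracts it only once, and returns the five training-file paths the input pipeline reads. The function name extract_if_needed is hypothetical (it is not the tutorial's own download function), and it assumes the archive unpacks into a cifar-10-batches-bin/ directory, which is the directory name the tutorial reads from.

```python
import os
import tarfile


def extract_if_needed(data_dir='/tmp/cifar10_data',
                      archive='cifar-10-binary.tar.gz'):
    """Hypothetical helper: reuse an archive that was downloaded by hand."""
    archive_path = os.path.join(data_dir, archive)
    if not os.path.exists(archive_path):
        raise FileNotFoundError('put %s into %s first' % (archive, data_dir))
    batches_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
    if not os.path.isdir(batches_dir):
        # Extract only once; later runs reuse the extracted files.
        with tarfile.open(archive_path, 'r:gz') as tar:
            tar.extractall(data_dir)
    # The five training files that distorted_inputs() expects.
    return [os.path.join(batches_dir, 'data_batch_%d.bin' % i)
            for i in range(1, 6)]
```

Checking for the extracted directory first means the archive is unpacked on the first run only, which is the same "skip work that is already done" idea as pre-downloading the file.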

Code walkthrough

Overall flow

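The entry point is main() in cifar10_train.py, which downloads/extracts the data if needed, recreates the training directory, and then calls train(). The code of train(), with my comments, is as follows: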
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():    # create a new graph and make it the default graph inside this block
    global_step = tf.contrib.framework.get_or_create_global_step()  # returns (creating it if necessary) the global step variable; later TF 1.x releases deprecate this contrib version in favor of tf.train.get_or_create_global_step
    #global_step = tf.Variable(0, name='global_step', trainable=False)
    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = FLAGS.batch_size
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    # For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. For workers, this utility sets proper session creator which waits for the chief to initialize/restore.
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)



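Here a dedicated _LoggerHook class is defined and registered with the mon_sess session, so that every 10 steps it prints the loss and the throughput. The last statement keeps running train_op in a loop, updating the network parameters, until a stop condition is reached: StopAtStepHook stops training at FLAGS.max_steps, and NanTensorHook raises an error if the loss becomes NaN.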
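To make the hook mechanism concrete, here is a plain-Python mock (hypothetical classes and functions, no TensorFlow) of the order in which the session calls a hook's begin, before_run and after_run. A fixed step limit stands in for StopAtStepHook and a NaN check stands in for NanTensorHook. It is a simplification for illustration, not how MonitoredTrainingSession is actually implemented.

```python
import math


class LoggerHook:
    """Mock of _LoggerHook: records (step, loss) every 10 steps instead of printing."""

    def __init__(self):
        self.logged = []

    def begin(self):
        self._step = -1

    def before_run(self):
        self._step += 1

    def after_run(self, loss_value):
        if self._step % 10 == 0:
            self.logged.append((self._step, loss_value))


def monitored_loop(train_step, hooks, last_step):
    """Hypothetical simplification of the MonitoredTrainingSession loop."""
    for hook in hooks:
        hook.begin()
    global_step = 0
    while global_step < last_step:            # plays the role of should_stop()
        for hook in hooks:
            hook.before_run()
        loss_value = train_step(global_step)  # plays the role of mon_sess.run(train_op)
        if math.isnan(loss_value):            # like NanTensorHook
            raise RuntimeError('NaN loss during training.')
        for hook in hooks:
            hook.after_run(loss_value)
        global_step += 1                      # train_op increments global_step
    return global_step


hook = LoggerHook()
steps = monitored_loop(lambda step: 1.0 / (step + 1), [hook], last_step=25)
print(steps, [s for s, _ in hook.logged])  # 25 [0, 10, 20]
```

The point of the design is that logging, stopping and NaN checking live in hooks, so the training loop itself stays a two-line while loop.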

Reading the data files

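train() first calls distorted_inputs in cifar10.py, which in turn calls distorted_inputs in cifar10_input.py; the main statements of the latter are: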

  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in range(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Pass the list of filenames to the tf.train.string_input_producer function. string_input_producer creates a FIFO queue for holding the filenames until the reader needs them.

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.

  # Randomly crop a [height, width] section of the image.
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Randomly flip the image horizontally.
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Because these operations are not commutative, consider randomizing
  # the order of their operation.
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Subtract off the mean and divide by the (adjusted) standard deviation of the pixels.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Linearly scales the image to have zero mean and unit variance.

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)

  # min_queue_examples = int(50000 * 0.4) = 20000; it is passed to tf.train.shuffle_batch as min_after_dequeue.

  print ('Filling queue with %d CIFAR images before starting to train. '
         'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  # Several threads read examples in parallel to assemble one training batch of
  # batch_size = 128 examples. Note that a whole batch is returned: the images
  # tensor has shape (128, 24, 24, 3), because the images were cropped to
  # IMAGE_SIZE (24 in cifar10_input.py) above.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)
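The distortions can be illustrated in plain Python on a nested [height][width][channel] list. The sketch below (illustrative function names, not TensorFlow's API) mimics random cropping to 24×24, a random horizontal flip, and the formula documented for tf.image.per_image_standardization, (x - mean) / max(stddev, 1/sqrt(N)) with N the number of values in the image. Brightness and contrast jitter are omitted to keep it short.

```python
import math
import random

IMAGE_SIZE = 24  # cifar10_input.py crops the 32x32 images to 24x24


def random_crop(img, size, rng):
    """Take a random size x size window from a nested [h][w][c] list."""
    top = rng.randint(0, len(img) - size)       # randint is inclusive
    left = rng.randint(0, len(img[0]) - size)
    return [row[left:left + size] for row in img[top:top + size]]


def random_flip_left_right(img, rng):
    """Mirror every row with probability 0.5."""
    return [row[::-1] for row in img] if rng.random() < 0.5 else img


def per_image_standardization(img):
    """(x - mean) / max(stddev, 1/sqrt(N)), N = number of values in the image."""
    flat = [v for row in img for px in row for v in px]
    n = len(flat)
    mean = sum(flat) / n
    stddev = math.sqrt(sum((v - mean) ** 2 for v in flat) / n)
    adjusted_stddev = max(stddev, 1.0 / math.sqrt(n))  # avoids dividing by ~0
    return [[[(v - mean) / adjusted_stddev for v in px] for px in row]
            for row in img]


rng = random.Random(0)
image = [[[rng.randrange(256) for _ in range(3)] for _ in range(32)]
         for _ in range(32)]
out = per_image_standardization(
    random_flip_left_right(random_crop(image, IMAGE_SIZE, rng), rng))
```

The lower bound 1/sqrt(N) on the standard deviation matters for a nearly uniform image: without it, the division would blow up instead of returning values near zero.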


The read_cifar10 function is as follows:
def read_cifar10(filename_queue):
  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # input format.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  # Read a record, getting filenames from the filename_queue.  No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.

  # FixedLengthRecordReader below, combined with reader.read and tf.decode_raw, is a common way to read fixed-length records from the files in the filename queue
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # To read binary files in which each record is a fixed number of bytes, use tf.FixedLengthRecordReader with the tf.decode_raw operation. The decode_raw op converts from a string to a uint8 tensor.
  # For example, the CIFAR-10 dataset uses a file format where each record is represented using a fixed number of bytes: 1 byte for the label followed by 3072 bytes of image data. Once you have a uint8 tensor, standard operations can slice out each piece and reformat as needed. For CIFAR-10, you can see how to do the reading and decoding in
  # The first bytes represent the label, which we convert from uint8->int32.
  # Below, tf.strided_slice extracts the first label_bytes bytes of record_bytes as the label
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes], [1]), tf.int32)  # in r0.12 the stride [1] must be passed explicitly

  # strided_slice Extracts a strided slice from a tensor.
  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
 

  # Below, tf.strided_slice extracts the image data from record_bytes
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes], [1]),
      [result.depth, result.height, result.width])

  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result
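The record layout that read_cifar10 decodes can be checked without TensorFlow. The plain-Python sketch below (hypothetical helper names) reads fixed 3073-byte records, takes byte 0 as the label, and rearranges the remaining depth-major bytes into an [h][w][c] layout, which is what the strided_slice, reshape and transpose([1, 2, 0]) steps above do.

```python
LABEL_BYTES = 1
HEIGHT, WIDTH, DEPTH = 32, 32, 3
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH          # 3072
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES      # 3073


def parse_record(record):
    """Return (label, image) with image as a nested [h][w][c] list of ints."""
    if len(record) != RECORD_BYTES:
        raise ValueError('expected %d bytes, got %d' % (RECORD_BYTES, len(record)))
    label = record[0]                          # indexing bytes gives an int
    pixels = record[LABEL_BYTES:]              # depth-major: all R, then G, then B
    # pixels[d * H * W + h * W + w] -> image[h][w][d], i.e. transpose [1, 2, 0]
    image = [[[pixels[d * HEIGHT * WIDTH + h * WIDTH + w] for d in range(DEPTH)]
              for w in range(WIDTH)]
             for h in range(HEIGHT)]
    return label, image


def iter_records(path):
    """Yield parsed records from a data_batch_N.bin file (no header or footer)."""
    with open(path, 'rb') as f:
        while True:
            record = f.read(RECORD_BYTES)
            if len(record) < RECORD_BYTES:
                break
            yield parse_record(record)
```

Because every record has exactly the same length, the reader never needs a delimiter or an index: record k simply starts at byte k * 3073.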
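As the code shows, read_cifar10 actually returns a single training example with two data members, result.label and result.uint8image. The _generate_image_and_label_batch function is as follows:
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  num_preprocess_threads = 16
  if shuffle:
    # Build a randomly shuffled batch with 16 threads; the read buffer keeps at least
    # 20000 examples (min_after_dequeue), and capacity is 20000 + 3 * 128.
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)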
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # Display the training images in the visualizer.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])

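In my view, the hard part of the loading step is this: examples are clearly read one at a time, yet a complete batch comes out at the end, and there is also a loading buffer of 20000 examples. The Jikexueyuan (极客学院) website explains it with a diagram, described as follows: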

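First, we create the dataflow graph. It consists of several pipeline stages connected by queues. The first stage generates filenames; we read them and enqueue them into a filename queue. The second stage reads data from the files (using a Reader), produces examples, and enqueues them into an example queue. Depending on your setup, you can actually run multiple independent copies of the second stage, so that several files are read in parallel. At the end of these stages is an enqueue operation that pushes into a queue, from which the next stage dequeues. Because we start threads that run these enqueue operations, our training loop can keep dequeuing examples from the example queue.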
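My understanding: FixedLengthRecordReader and reader.read read a single example from one of the files in the filename list; whenever the reader needs a new file, it dequeues the next filename from the filename queue, which corresponds to the dequeue in the diagram. Everything after that operates on this one example. Finally, tf.train.shuffle_batch starts 16 threads that each repeatedly run this read-and-preprocess subgraph (all sharing the same reader) and enqueue the result into a shuffling queue, from which batches of 128 examples are dequeued. These functions must be used together: although the program never explicitly codes the multithreading or the batch assembly, TensorFlow implements the mechanism for us, and MonitoredTrainingSession starts the queue-runner threads automatically.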
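The queue behavior described above can be imitated with ordinary Python threads. In the toy model below (hypothetical classes, not TensorFlow's implementation), several worker threads share one "reader" that hands out example ids, each thread enqueues one example at a time, and dequeue_many only returns a batch when min_after_dequeue examples would still remain, with capacity = min_after_dequeue + 3 * batch_size as in the tutorial. It shows how single-example reads still yield complete, shuffled batches.

```python
import itertools
import random
import threading


class ToyShuffleQueue:
    """Toy stand-in for the shuffling queue behind tf.train.shuffle_batch."""

    def __init__(self, capacity, min_after_dequeue, seed=0):
        self.capacity = capacity
        self.min_after_dequeue = min_after_dequeue
        self.items = []
        self.closed = False
        self.cond = threading.Condition()
        self.rng = random.Random(seed)

    def enqueue(self, item):
        """Block while full; return False once the queue is closed."""
        with self.cond:
            while len(self.items) >= self.capacity and not self.closed:
                self.cond.wait()
            if self.closed:
                return False
            self.items.append(item)
            self.cond.notify_all()
            return True

    def dequeue_many(self, n):
        """Block until n items can be taken with min_after_dequeue left behind."""
        with self.cond:
            while len(self.items) < n + self.min_after_dequeue:
                self.cond.wait()
            batch = []
            for _ in range(n):
                i = self.rng.randrange(len(self.items))
                # Swap-and-pop removes a random element in O(1).
                self.items[i], self.items[-1] = self.items[-1], self.items[i]
                batch.append(self.items.pop())
            self.cond.notify_all()
            return batch

    def close(self):
        with self.cond:
            self.closed = True
            self.cond.notify_all()


def shuffle_batches(num_batches, batch_size=128, min_after_dequeue=20000,
                    num_threads=16):
    q = ToyShuffleQueue(min_after_dequeue + 3 * batch_size, min_after_dequeue)
    counter = itertools.count()          # one shared "reader": example ids 0, 1, 2, ...
    reader_lock = threading.Lock()

    def read_one():
        with reader_lock:
            return next(counter)

    def worker():                        # each thread: read one example, enqueue it
        while q.enqueue(read_one()):
            pass

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    batches = [q.dequeue_many(batch_size) for _ in range(num_batches)]
    q.close()
    for t in threads:
        t.join()
    return batches
```

The large min_after_dequeue is what makes the shuffle meaningful: each batch is drawn at random from a pool of at least 20000 buffered examples rather than from the few most recently read ones.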

Model definition

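The model definition itself, by contrast, needs little explanation; just note the following points:
1. Like the MNIST example, the network uses two convolutional layers and two pooling layers, but it also adds local response normalization (LRN) layers in between, plus two fully connected layers.
2. When the two fully connected layers are defined, _variable_with_weight_decay adds an L2 penalty on the weights (half the sum of their squares) to the final loss; the relevant code is: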

weights = _variable_with_weight_decay('weights', shape=[384, 192],
                                          stddev=0.04, wd=0.004)
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.
  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.
  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.
  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    #Computes half the L2 norm of a tensor without the sqrt
    #output = sum(t ** 2) / 2
    tf.add_to_collection('losses', weight_decay)
  return var

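In other words, tf.nn.l2_loss sums the squares of all the weights and divides by 2; weight_decay multiplies that by the coefficient 0.004 and adds it to the 'losses' collection used in the loss computation.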
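The arithmetic can be checked in plain Python. The sketch below reproduces the documented formula of tf.nn.l2_loss and the weight-decay term; the extra weight_decay_gradient function is my own addition for illustration. It shows why this is called weight decay: the gradient of wd * sum(w^2) / 2 is wd * w, so every gradient step pulls each weight slightly toward zero. The weight values are made up.

```python
def l2_loss(values):
    """Same formula as tf.nn.l2_loss: sum(t ** 2) / 2, with no square root."""
    return sum(v * v for v in values) / 2.0


def weight_decay_term(values, wd):
    """The term _variable_with_weight_decay adds to the 'losses' collection."""
    return l2_loss(values) * wd


def weight_decay_gradient(values, wd):
    """d/dw of wd * sum(w ** 2) / 2 is wd * w: each SGD step shrinks w toward 0."""
    return [wd * v for v in values]


weights = [0.5, -1.0, 2.0]          # hypothetical flattened weights
term = weight_decay_term(weights, wd=0.004)
```

Dividing by 2 is purely a convenience: it cancels the factor 2 from differentiating w^2, which leaves the clean gradient wd * w.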

Training objective

The loss is defined as follows:
def loss(logits, labels):
  """Add L2Loss to all the trainable variables.
  Add summary for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]
  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  # Wrapper for Graph.add_to_collection() using the default graph.
  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

  # cross_entropy_mean is also added to the 'losses' collection, so tf.add_n sums it together with the weight-decay terms

Through the 'losses' collection, the final tf.add_n adds the usual mean cross-entropy and the weight L2 penalty terms described above to form the total loss.
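A plain-Python sketch of what this loss computes, with made-up logits and hypothetical function names: the per-example sparse softmax cross-entropy, -log softmax(logits)[label], its mean over the batch, and the sum with the weight-decay terms that tf.add_n performs over the 'losses' collection.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """-log(softmax(logits)[label]), using log-sum-exp for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def total_loss(batch_logits, labels, weight_decay_terms):
    per_example = [sparse_softmax_cross_entropy(z, y)
                   for z, y in zip(batch_logits, labels)]
    cross_entropy_mean = sum(per_example) / len(per_example)
    # Plays the role of the 'losses' collection: mean cross-entropy + decay terms.
    losses = [cross_entropy_mean] + list(weight_decay_terms)
    return sum(losses)               # like tf.add_n(tf.get_collection('losses'))
```

"Sparse" means the labels are class indices (0 to 9) rather than one-hot vectors, which matches the int labels produced by read_cifar10.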

Training iterations

def train(total_loss, global_step):
  """Train CIFAR-10 model.
  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.
  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size  #50000/128
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)

  #When training a model, it is often recommended to lower the learning rate as the training progresses. This function applies an exponential decay function to a provided initial learning rate. It requires a global_step value to compute the decayed learning rate. You can just pass a TensorFlow variable that you increment at each training step.
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
  # tf.control_dependencies returns a context manager that specifies control dependencies:
  # all operations constructed inside the with block run only after control_inputs (here loss_averages_op).
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.summary.histogram(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')
  #Does nothing. Only useful as a placeholder for control edges.
  return train_op

The main difference in the training procedure is

lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
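That is, the learning rate is recomputed at every iteration. Here INITIAL_LEARNING_RATE is 0.1, global_step is the current iteration count, and decay_steps is the number of iterations after which the learning rate is multiplied by LEARNING_RATE_DECAY_FACTOR. In this program LEARNING_RATE_DECAY_FACTOR = 0.1 and decay_steps = num_batches_per_epoch * NUM_EPOCHS_PER_DECAY = 50000/128 × 350 = 136718 (after truncation to int), so roughly every 137,000 iterations lr drops to 0.1 of its previous value; with staircase=True the drop happens in discrete jumps rather than continuously. Gradient descent then uses the current lr: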

  opt = tf.train.GradientDescentOptimizer(lr)
  grads = opt.compute_gradients(total_loss)
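The schedule can be reproduced in plain Python from the documented formula decayed_lr = lr * decay_rate ^ (global_step / decay_steps), where staircase=True floors the exponent. The constants are the values quoted above (batch size 128), and the sketch assumes true division (Python 3 semantics) when computing num_batches_per_epoch.

```python
import math

INITIAL_LEARNING_RATE = 0.1
LEARNING_RATE_DECAY_FACTOR = 0.1
NUM_EPOCHS_PER_DECAY = 350.0
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
BATCH_SIZE = 128

num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE  # 390.625
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)        # 136718


def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
                      staircase=True):
    """learning_rate * decay_rate ** (global_step / decay_steps)."""
    exponent = global_step / decay_steps
    if staircase:
        exponent = math.floor(exponent)   # integer exponent -> step function
    return learning_rate * decay_rate ** exponent
```

With staircase=True the rate stays at 0.1 for steps 0 to 136717, becomes 0.01 at step 136718, and so on; without it, the rate shrinks a little at every step.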
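And the later statement
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
creates a moving-average object. MOVING_AVERAGE_DECAY is 0.9999 in cifar10.py, and because global_step is passed as num_updates, the decay actually used at each step is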
           min(decay, (1 + num_updates) / (10 + num_updates))
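Every trainable variable updated by gradient descent above gets a shadow copy that is smoothed with this decay as follows (the variables themselves are left unchanged; cifar10_eval.py restores the shadow values for evaluation):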
            shadow_variable = decay * shadow_variable + (1 - decay) * variable
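Both formulas can be traced in plain Python. In this sketch (hypothetical function names), a weight jumps from 0.0 to 1.0 at step 0 and stays there. Because num_updates is small early in training, the effective decay starts at 0.1, so the shadow value catches up quickly instead of staying stuck near its initial value; only from step 89990 onward does the decay reach 0.9999.

```python
MOVING_AVERAGE_DECAY = 0.9999


def ema_decay(decay, num_updates):
    """Decay actually used when num_updates (here: global_step) is passed."""
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


def update_shadow(shadow, variable, decay):
    return decay * shadow + (1 - decay) * variable


# Hypothetical trace: a weight that jumps from 0.0 to 1.0 and stays there.
shadow = 0.0
for step in range(10):
    shadow = update_shadow(shadow, 1.0, ema_decay(MOVING_AVERAGE_DECAY, step))
```

Late in training, the 0.9999 decay averages the weights over roughly the last 10,000 steps, which smooths out the noise of individual SGD updates.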

PS

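In practice I ran TensorFlow r0.12 and got errors and warnings caused by version incompatibilities, mainly these two:
1. The original program calls tf.strided_slice in cifar10_input with only 3 arguments, but r0.12 requires 4; the missing one is the stride, which can simply be set to [1].
2. The logging code makes heavy use of functions from tf.contrib.deprecated, which no longer work; they can be replaced directly with tf.summary.scalar and similar functions.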
