MNIST Code Walkthrough

源码地址:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist

1. mnist.py: builds a fully connected MNIST model

Imports:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf

The MNIST dataset has 10 classes, one for each of the digits 0 through 9:

NUM_CLASSES = 10

MNIST images are all 28*28 pixels:

IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
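
As a quick aside (my own sketch, not part of the tutorial), this is why IMAGE_PIXELS is 784: each 28*28 image is flattened into a single 784-dimensional vector before being fed to the model:

import numpy as np

image = np.zeros((28, 28), dtype=np.float32)  # one MNIST-sized image
flat = image.reshape(-1)                      # flatten to a row vector
print(flat.shape)                             # (784,)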

Model skeleton: inference()

For the usage of tf.name_scope(), see item 9 at https://blog.csdn.net/Dorothy_Xue/article/details/83903763
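
To make the scoping concrete, here is a minimal sketch (my own example, not from the tutorial) showing that tf.name_scope() prefixes the names of ops and variables created inside it:

import tensorflow as tf

with tf.name_scope('hidden1'):
    w = tf.Variable(tf.zeros([784, 128]), name='weights')
print(w.name)  # prints "hidden1/weights:0" -- the scope name becomes a prefix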

def inference(images, hidden1_units, hidden2_units):
#Implementation: two fully connected layers with ReLU activations, plus a softmax (linear) layer.
#Shapes: images 100*784; hidden 1: 784*128 ==> 100*128; hidden 2: 128*32 ==> 100*32;
#softmax layer: 32*10 ==> 100*10  [for each of the 100 images, predict the probability of each digit 0-9]
  """Build the MNIST model up to where it may be used for inference.

  Args:
    images: Images placeholder, from inputs().
    hidden1_units: Size of the first hidden layer (number of units; here 128).
    hidden2_units: Size of the second hidden layer (here 32).

  Returns:
    softmax_linear: Output tensor with the computed logits,
      i.e. the model's output predictions.
  """
  # Hidden 1
  with tf.name_scope('hidden1'):  # define hidden layer 1
    # Weight variable: shape [784, 128], truncated-normal init, named "weights"
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    # Bias variable: a vector of hidden1_units zeros, named "biases"
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    # Hidden layer 1 output: ReLU applied to images x weights + biases
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
    


  # Hidden 2
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)


  # Linear 
  with tf.name_scope('softmax_linear'):
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases


  return logits  # return the predictions (logits)
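
To verify the shape bookkeeping in the comments above, here is a small sketch (assuming the definitions above are in scope) that pushes a batch of 100 flattened images through inference() and checks that the logits come out as 100*10:

import tensorflow as tf

images = tf.placeholder(tf.float32, shape=(100, IMAGE_PIXELS))
logits = inference(images, 128, 32)
print(logits.shape)  # (100, 10): one row of 10 class scores per image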

Compute the loss between the output and the ground truth (here, between the logits and the labels):

For the usage of tf.to_int64(), see item 10 at https://blog.csdn.net/Dorothy_Xue/article/details/84975706

For tf.losses.sparse_softmax_cross_entropy(), see item 11 at the same address.

def loss(logits, labels):
  """Calculates the loss from the logits and the labels.

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size].

  Returns:
    loss: Loss tensor of type float.
  """


  labels = tf.to_int64(labels)  # cast labels to an int64 tensor
  # Compute the softmax cross-entropy between the labels and logits,
  # averaged over the batch, and return it as the loss
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
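
To see what this loss actually computes, here is a small numeric sketch (my own example): tf.losses.sparse_softmax_cross_entropy() averages, over the batch, the per-example cross-entropy -log(softmax(logits)[label]), matching the lower-level tf.nn op:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.2, 3.0, 0.3]])
labels = tf.constant([0, 1])  # true classes for the two examples
loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
manual = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
with tf.Session() as sess:
    print(sess.run([loss_op, manual]))  # the two values agree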

The training op that minimizes the loss:

For tf.summary.scalar(), see item 12 at https://blog.csdn.net/Dorothy_Xue/article/details/84975706

For tf.train.GradientDescentOptimizer(), see item 13 at the same address.

For optimizer.minimize(), see item 14 at the same address.

def training(loss, learning_rate):
  """Sets up the training Ops.

  # Personally, the summarizer feels like a log file, handy for later
  # visualization with TensorBoard.
  Creates a summarizer to track the loss over time in TensorBoard.

  Creates an optimizer and applies the gradients to all trainable variables.

  The Op returned by this function is what must be passed to the
  `sess.run()` call to cause the model to train.

  Args:
    loss: Loss tensor, from loss().
    learning_rate: The learning rate to use for gradient descent.

  Returns:
    train_op: The Op for training.
  """


  # Add a scalar summary for the snapshot loss.
  tf.summary.scalar('loss', loss)

  # Create the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)

  # Create a variable to track the global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)

  # Use the optimizer to apply the gradients that minimize the loss
  # (and also increment the global step counter) as a single training step.
  train_op = optimizer.minimize(loss, global_step=global_step)

  # Return the op that performs one training step
  return train_op
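
For intuition, minimize() above is shorthand for two steps, compute_gradients() followed by apply_gradients(). A self-contained sketch (a toy loss, not the MNIST model) of the equivalent decomposition:

import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.square(x)  # toy loss with its minimum at x = 0
optimizer = tf.train.GradientDescentOptimizer(0.1)
global_step = tf.Variable(0, name='global_step', trainable=False)
# minimize(loss, global_step) == compute_gradients + apply_gradients:
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(train_op)
    print(sess.run([x, global_step]))  # x moves toward 0; global_step is 3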

Function for evaluating the accuracy of the predictions:

For tf.nn.in_top_k(), see item 15 at https://blog.csdn.net/Dorothy_Xue/article/details/84975706

For tf.cast(), see item 16 at the same address.

For tf.reduce_sum(), see item 17 at the same address.

def evaluation(logits, labels):
  """Evaluate the quality of the logits at predicting the label.
  #用于评价模型输出结果的质量

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size], with values in the
      range [0, NUM_CLASSES).

  Returns:
    A scalar int32 tensor with the number of examples (out of batch_size)
    that were predicted correctly.
  """
  # For a classifier model, we can use the in_top_k Op.
  # It returns a bool tensor with shape [batch_size] that is true for
  # the examples where the label is in the top k (here k=1)
  # of all logits for that example.

  # in_top_k returns a boolean vector of length batch_size: for each example
  # it checks whether the logit indexed by the true label is among the top
  # k=1 (i.e. is the largest) of that example's logits. In other words, it
  # compares labels against logits row by row to see whether each prediction
  # is correct, yielding True if so and False otherwise.
  correct = tf.nn.in_top_k(logits, labels, 1)

  # Return the number of true entries.
  return tf.reduce_sum(tf.cast(correct, tf.int32))
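
A tiny numeric sketch (my own example) of what in_top_k does with k=1: an example counts as correct exactly when its true label indexes the largest entry of its logits row:

import tensorflow as tf

logits = tf.constant([[0.1, 0.9],   # predicts class 1
                      [0.8, 0.2]])  # predicts class 0
labels = tf.constant([1, 1])        # both true labels are 1
correct = tf.nn.in_top_k(logits, labels, 1)
num_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
with tf.Session() as sess:
    print(sess.run([correct, num_correct]))  # [True, False], 1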

2. fully_connected_feed.py: trains the MNIST model built above on the downloaded dataset, with input supplied via a feed dictionary (feed_dict)

Imports:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

Basic model parameters exposed as external flags:

# Basic model parameters as external flags.
FLAGS = None

Generate placeholders to represent the input tensors (this feels like the `int i;` declaration in `int i; i = 0;`):

def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.
  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.

  Args:
    batch_size: The batch size will be baked into both placeholders.

  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         mnist.IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder
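
The placeholder/feed_dict pattern these two placeholders rely on, in a minimal self-contained sketch (my own example): nothing is computed until concrete values are fed in at run time:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 2))  # "declare" the input
y = tf.reduce_sum(x)
with tf.Session() as sess:
    # "assign" the input by feeding a value for the placeholder
    print(sess.run(y, feed_dict={x: [[1.0, 2.0], [3.0, 4.0]]}))  # 10.0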

Fill in values for this training step: pass in a dictionary that assigns values to the placeholders above (this feels like the `i = 0;` assignment in `int i; i = 0;`):

For feed_dict, see item 8 at https://blog.csdn.net/Dorothy_Xue/article/details/83903763

For X.next_batch(), see item 18 at https://blog.csdn.net/Dorothy_Xue/article/details/84975706

def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict

Evaluate the model over one full epoch of data:

For xrange(), see item 19 at https://blog.csdn.net/Dorothy_Xue/article/details/84975706

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one evaluation against the full epoch of data.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / num_examples
  print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))

Below is the core function!

Purpose: train MNIST through a series of steps.

For tf.train.Saver(), see item 20 at https://blog.csdn.net/Dorothy_Xue/article/details/84979049

For tf.summary.FileWriter(), see item 22 at the same address.

For time.time(), see item 23 at the same address.

def run_training():
  """Train MNIST for a number of steps."""
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    # i.e. define the op that runs the inputs through inference() to get the logits
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    # i.e. the op that computes the loss from the logits and labels via loss()
    loss = mnist.loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    # i.e. the op that minimizes the loss, from training()
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    # i.e. the op that scores output quality (the number of correct predictions)
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary Tensor based on the TF collection of Summaries.
    # merge all summaries into one tensor so they can later be written to disk for TensorBoard
    summary = tf.summary.merge_all()

    # Add the variable initializer Op.
    init = tf.global_variables_initializer()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # And then after everything is built:

    # Run the Op to initialize the variables.
    sess.run(init)

    # Start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()  # record the start time for this step

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      # i.e. bind the image and label placeholders to the current batch
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      # i.e. run the train_op and loss ops on the batch fetched above to get the loss
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      # elapsed time for this step
      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      # i.e. report the current status every 100 training steps
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # Save a checkpoint and evaluate the model periodically.
      # i.e. save a checkpoint and evaluate the model every 1000 steps (and at the last step)
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)

Main function:

For tf.gfile.Exists(), see item 21 at https://blog.csdn.net/Dorothy_Xue/article/details/84979049

def main(_):
  if tf.gfile.Exists(FLAGS.log_dir):
    tf.gfile.DeleteRecursively(FLAGS.log_dir)
  tf.gfile.MakeDirs(FLAGS.log_dir)
  run_training()
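
On local paths, the gfile calls above behave like their standard-library counterparts, so (as a rough equivalence, assuming a local filesystem) main() does the following: wipe the log directory if it exists, then recreate it:

import os
import shutil

log_dir = '/tmp/tensorflow/mnist/logs/fully_connected_feed'
if os.path.exists(log_dir):   # ~ tf.gfile.Exists
    shutil.rmtree(log_dir)    # ~ tf.gfile.DeleteRecursively
os.makedirs(log_dir)          # ~ tf.gfile.MakeDirs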

Define the command-line arguments:

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--max_steps',
      type=int,
      default=2000,
      help='Number of steps to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.  Must divide evenly into the dataset sizes.'
  )
  parser.add_argument(
      '--input_data_dir',
      type=str,
      default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                           'tensorflow/mnist/input_data'),
      help='Directory to put the input data.'
  )
  parser.add_argument(
      '--log_dir',
      type=str,
      default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                           'tensorflow/mnist/logs/fully_connected_feed'),
      help='Directory to put the log data.'
  )
  parser.add_argument(
      '--fake_data',
      default=False,
      help='If true, uses fake data for unit testing.',
      action='store_true'
  )

Define FLAGS and launch the run (these two lines still sit inside the `if __name__ == '__main__':` block):

For tf.app.run(), see item 24 at https://blog.csdn.net/Dorothy_Xue/article/details/84979049

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
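
Roughly speaking (an assumption based on the TF 1.x implementation, simplified here), tf.app.run() finishes any remaining flag parsing and then hands the leftover argv to main(), exiting with its return value:

import sys

def app_run(main, argv):
    # simplified stand-in for tf.app.run: call main and exit with its result
    sys.exit(main(argv))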

 
