TensorFlow MNIST手写数字识别学习笔记(一)

TensorFlow MNIST手写数字识别学习笔记(一)

MNIST手写数字识别模型建立

MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.
在这里插入图片描述

源码下载

源码地址:代码下载地址

代码包含:

文件目的
fully_connected_feed.py网络配置,TF设置以及启动网络
input_data.py下载用于训练和测试的MNIST数据集的源码
mnist.p y网络结构定义

下面我们具体解析一下fully_connected_feed.py这个文件
定义 placeholder_inputs函数,为读取的图片和标签预留占位

def placeholder_inputs(batch_size):
"""为读取的数据集图片占位,预留float32位,shape为batch_size为100(程序后面定义了),IMAGE_PIXELS为28*28(程序后面定义了),因为图片为28*28像素"""
      images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                      mnist.IMAGE_PIXELS)
"""为读取的标签占位,预留预留float32位,shape为batch_size为100"""
      labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
      return images_placeholder, labels_placeholder

定义fill_feed_dict函数,填充读取的图片和标签

def fill_feed_dict(data_set, images_pl, labels_pl):
"""把data_set读取来的图片填充到images_feed,fake_data标记是用于单元测试的,可以不必理会。"""
    images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
"""把得到的images_feed, labels_feed分别填充到images_pl, labels_pl"""
    feed_dict = {
        images_pl: images_feed,
        labels_pl: labels_feed,
  }
    return feed_dict

定义do_eval函数,计算几轮,正确值,与预测值

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
 """计算正确的预测值,默认为0"""
   true_count = 0 
 """对从网上读取的num_examples进行地板除,结果向下取整"""
   steps_per_epoch = data_set.num_examples // FLAGS.batch_size
   num_examples = steps_per_epoch * FLAGS.batch_size
 """建立一个循环,把数据循环一遍"""
   for step in xrange(steps_per_epoch):
   """每次把之前预留好的占位用真实数据填充一遍"""
     feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
   """叠加计算正确的次数 """
     true_count += sess.run(eval_correct, feed_dict=feed_dict)
   """计算预测值"""
   precision = float(true_count) / num_examples
   print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
         (num_examples, true_count, precision))

定义run_training函数,对数据进行训练

def run_training():
 """
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir,
              FLAGS.fake_data)
   """定义了一个图:Graph """
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
   """为images_placeholder, labels_placeholder预留FLAGS.batch_size大小的占位 """
    images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)
    """建立一个有两层隐藏层的网络结构 """
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)
     """为网络添加计算损失的loss OP"""
    loss = mnist.loss(logits, labels_placeholder)
    """ 为网络添加计算斜率的gradients Op"""
    train_op = mnist.training(loss, FLAGS.learning_rate)
    """为网络添加evaluation OP,计算正确数量"""
    eval_correct = mnist.evaluation(logits, labels_placeholder)
    """建立summary Tensor,merge_all可以将所有summary全部保存到磁盘,以便tensorboard显示。"""
    summary = tf.summary.merge_all()
    """ 添加初始化OP"""
    init = tf.global_variables_initializer()
    """ 创建一个tf.train.Saver() 类保存所有变量和网络结构体到文件中""" 
    saver = tf.train.Saver()
    """ 创建一个在图中session OP."""
    sess = tf.compat.v1.Session()
    """ 指定一个文件用来保存图。格式:tf.summary.FileWritter(path,sess.graph)可以调用其add_summary()方法将训练过程数据保存在filewriter指定的文件中.""" 
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
    # And then after everything is built:
    """开始session OP运行 """ 
    sess.run(init)
    """开始循环训练 """
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      """ 把文件填充到具体的位置"""
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)
      """_, 是忽略第一个返回的参数的意思,计算损失函数"""
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)
      """计算运算的时间"""
      duration = time.time() - start_time
      """打印出计算次数和损失值 """
      if step % 100 == 0:
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))

        """更新log文件."""
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      """把计算好的数据保存到checkpoint文件."""
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)

        """计算训练集的正确值,与预测值."""

        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)

         """计算验证集的正确值,与预测值."""
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)

        """计算测试集的正确值,与预测值."""
        print('Test Data Eval:')

        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)

定义main函数,创建FLAGS.log_dir文件

def main(_):
    """判断目录或文件是否存在,FLAGS.log_dir可为目录路径或带文件名的路径,有该目录则返回True,否则False。"""
  if tf.gfile.Exists(FLAGS.log_dir):
    """递归删除所有目录及其文件,FLAGS.log_dir即目录名,无返回。"""
    tf.gfile.DeleteRecursively(FLAGS.log_dir)
   """以递归方式建立父目录及其子目录,如果目录已存在且是可覆盖则会创建成功,否则报错,无返回。"""
  tf.gfile.MakeDirs(FLAGS.log_dir)

  run_training()

"""程序启动时的一些参数配置"""
if __name__ == '__main__':

  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--max_steps',
      type=int,
      default=2000,
      help='Number of steps to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.  Must divide evenly into the dataset sizes.'
  )

  parser.add_argument(
      '--input_data_dir',
      type=str,
      default=os.path.join(os.getcwd(),
                           'input_data'),
      help='Directory to put the input data.'
  )

  parser.add_argument(
      '--log_dir',
      type=str,
      default=os.path.join(os.getcwd(),
                           'logs/fully_connected_feed'),
      help='Directory to put the log data.'
  )
  parser.add_argument(
      '--fake_data',
      default=False,
      help='If true, uses fake data for unit testing.',
      action='store_true'
  )
"""使用配置好的参数启动程序"""
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值