Python3实现TensorFlow运作方式入门

TensorFlow运作方式入门中文文档, 这里是直译过来的,所以很多逻辑顺序不是很合理,刚开始看的时候一脸懵,需要整体开下来,多看几遍,然后才能理解一点,不过由于很多代码是基于python2实现的,换成python3实现起来对于刚入门的我们不是那么容易,下面是源码,里面有带解释,希望能帮助大家更好的理解:

# -*- coding: utf-8 -*-

import os
import sys
import time
import argparse
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

FLAGS = None


# placeholder_inputs()函数将生成两个tf.placeholder操作,定义传入图表中的shape参数,
# shape参数中包括batch_size值,后续还会将实际的训练用例传入图表。
# 在训练循环(training loop)的后续步骤中,传入的整个图像和标签数据集会被切片,
# 以符合每一个操作所设置的batch_size值,占位符操作将会填补以符合这个batch_size值。
# 然后使用feed_dict参数,将数据传入sess.run()函数。
def placeholder_inputs(batch_size):
    images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
    return images_placeholder, labels_placeholder


# fill_feed_dict函数会查询给定的DataSet,索要下一批次batch_size的图像和标签,
# 与占位符相匹配的Tensor则会包含下一批次的图像和标签。
def fill_feed_dict(data_set, images_placeholder, labels_placeholder):
    images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
    # 然后,以占位符为哈希键,创建一个Python字典对象,键值则是其代表的反馈Tensor。
    feed_dict = {
        images_placeholder: images_feed,
        labels_placeholder: labels_feed
    }

    return feed_dict


# 对模型进行评估(eval即evaluation),Eval Output
def do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
            data_set):
    true_count = 0
    steps_per_epoch = data_set.num_examples
    num_examples = steps_per_epoch * FLAGS.batch_size
    for step in range(num_examples):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        # 累加所有in_top_k操作判定为正确的预测之和
        true_count += sess.run(eval_correct, feed_dict=feed_dict)

    # 准确率 = 正确测试的总数 除以例子总数
    precision = float(true_count) / num_examples
    print('Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))


def run_training():
    # 在run_training()方法的一开始,input_data.read_data_sets()函数会确保你的本地训练文件夹中,
    # 已经下载了正确的数据,然后将这些数据解压并返回一个含有DataSet实例的字典。
    data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

    # 在默认图表中创建模型
    with tf.Graph().as_default():
        # 生成两个tf.placeholder操作
        images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)

        # 构建从推理模型计算预测的图
        logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)

        # 将计算损失的操作添加到图表中
        loss = mnist.loss(logits, labels_placeholder)

        # 向图中添加计算和应用梯度的操作
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # 添加Op以在评估期间将逻辑与标签进行比较
        eval_correct = mnist.evaluation(logits, labels_placeholder)

        # 所有的即时数据(在这里只有一个)都要在图表构建阶段合并至一个操作(op)中。
        summary_op = tf.summary.merge_all()

        # 添加变量初始化器Op
        init = tf.global_variables_initializer()

        # 保存检查点(checkpoint)
        # 为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoint file),
        # 我们实例化一个tf.train.Saver
        saver = tf.train.Saver()

        sess = tf.Session()
        # 用于写入包含了图表本身和即时数据具体值的事件文件
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        sess.run(init)

        for step in range(FLAGS.max_steps):
            start_time = time.time()

            feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder)

            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            duration = time.time() - start_time

            if step % 100 == 0 :
                print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))

                # 每次运行summary_op时,都会往事件文件中写入最新的即时数据
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                # 函数的输出会传入事件文件读写器(writer)的add_summary()函数
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            # 保存检查点并定期评估模型
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
                # 向训练文件夹中写入包含了当前所有可训练变量值得检查点文件
                saver.save(sess, checkpoint_file, global_step=step)

                print('使用训练数据集对模型进行评估')
                do_eval(
                    sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.train
                )
                print('使用验证数据集对模型进行评估')
                do_eval(
                    sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.validation
                )
                print('使用测试数据集对模型进行评估')
                do_eval(
                    sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.test
                )


def main(_):
    if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
    tf.gfile.MakeDirs(FLAGS.log_dir)
    run_training()


if __name__ == '__main__':
    #
    # 创建一个命令行解析模块的解析对象
    parser = argparse.ArgumentParser()
    # add_argumen()第一个是选项,第二个是数据类型,第三个默认值,第四个是help命令时的说明
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.01,
        help='Initial leaning rate'
    )

    parser.add_argument(
        '--max_steps',
        type=int,
        default=2000,
        help='Number of steps to run trainer.'
    )

    parser.add_argument(
        '--hidden1',
        type=int,
        default=128,
        help='Number of units in hedden layer 1.'
    )

    parser.add_argument(
        '--hidden2',
        type=int,
        default=32,
        help='Number of units in hedden layer 2.'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=10,
        help='Batch size. divide evenly into the dataset sizes.'
    )

    parser.add_argument(
        '--input_data_dir',
        type=str,
        default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'), 'tensorflow/mnist/input_data'),
        help='Directory to put the input data.'
    )

    parser.add_argument(
        '--log_dir',
        type=str,
        default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'), 'tensorflow/mnist/logs/fully_connected_feed'),
        help='Directory to put the log data.'
    )

    parser.add_argument(
        '--fake_data',
        default=False,
        help='If true, uses fake data for unit testing.',
        action='store_true'
    )

    # 有时间一个脚本只需要解析所有命令行参数中的一小部分,
    # 剩下的命令行参数给两一个脚本或者程序
    # 在接受到多余的命令行参数时不报错
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

如果有碰到问题,请在下方进行留言,将第一时间为你解决!

 

【The End!】

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值