tf.app.run(main=None, argv=None)

tf.app.run(main=None, argv=None)

TensorFlow
https://tensorflow.google.cn/

TensorFlow -> API
https://tensorflow.google.cn/versions

TensorFlow 1.x -> r1.4
https://github.com/tensorflow/docs/tree/r1.4/site/en/api_docs

tf.app.run(main=None, argv=None) 执行程序中的 main(_) 函数,并解析命令行参数。可选参数是 main=None, argv=Noneargv 可理解为 tf.app.run(main=None, argv=None) 输入列表。

1. /usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generic entry point script."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys as _sys

from tensorflow.python.platform import flags
from tensorflow.python.util.all_util import remove_undocumented


def _benchmark_tests_can_log_memory():
  return True


def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(_sys.argv[:1] + flags_passthrough))


_allowed_symbols = [
    'run',
    # Allowed submodule.
    'flags',
]

remove_undocumented(__name__, _allowed_symbols)

flags_passthrough = f._parse_flags(args=args) 用来解析命令行参数的函数。

2. tf.app.run()

2.1 调用示例

if __name__ == '__main__':
    tf.app.run()

该语句经常用于自己测试。

文件名定为 my_model.py__name__ 有两种情况:

  1. 当该文件直接运行时,执行 python my_model.py
    __name__ 数值是 __main__,因此会执行 if 内的语句。

  2. my_model.py 文件以模块形式导入时,执行 import my_model
    __name__ 数值是 my_model,因此不会执行 if 内的语句

__name__ = 模块名,每个 Python 模块 (Python 文件) 都包含内置的变量 __name__Python 中的模块名可以分成两类:

  • 第一类是 xxx.py 文件,模块名就是 xxx。例如 yongqiang.py,那么模块名是 yongqiang,调用方法就是 import yongqiang
  • 第二类是 __main__,当直接运行 xxx.py 文件时,缺省调用的模块名是 __main__

if语句用来判断,该模块是正在被 import 还是被 SHELL 单独运行,被 SHELL 单独运行时,为 True,执行 tf.app.run()。在该模块被 import 时,main(_) 为模块中的一个函数。

2.2 函数原型

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""

Runs the program with an optional main function and argv list.
主函数中的 tf.app.run() 会调用 main(_),并传递参数,因此必须在 main(_) 函数中设置一个参数。如果要更换 main(_) 函数为 test(args),只需要在 tf.app.run() 中传入一个指定的函数名即可 tf.app.run(test)

如果你的代码中的入口函数不叫 main(_),而是一个其他名字的函数 test(args),则你应该这样写入口 tf.app.run(test)

2.3 代码解析

f = flags.FLAGS
flags = tf.app.flags
FLAGS = flags.FLAGS

tf.app.flags 用于接收命令行传递参数。

# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None

将传入的参数 argv 从第二个开始切片 copyargs 形成一个列表 (第一个为函数名),如果没有则传 args=None

FLAGS = tf.app.flags.FLAGS 语句存在,表示输入已经解析。tf.app.run()argv=None,通过 args = argv[1:] if argv else None 语句,可知 args=None (即不指定,后面会自动解析 command)。

# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access

f = flags.FLAGS 构造了解析器 f 用以解析 argsflags_passthrough = f._parse_flags(args=args) 解析 args 列表或者 command 输入。args 列表为空,则解析 command 输入,返回的 flags_passthrough 内为无法解析的数据列表 (不包括文件名)。

main = main or sys.modules['__main__'].main 默认执行参数中指定的 main 函数。若 main=None,则默认程序中 main(_) 函数。如果 run(main=None, argv=None) 函数的 main 参数非 None,则 main 为输入参数 main。反之,读取代码中的 main() 函数为 main(_)

在没有传入主函数参数时,就认为当前模块中已经有一个叫 main(_) 的主函数,将 main(_) 赋给等号左边 main。在传入主函数参数时,将传入的当前模块自己定义的主函数传给等号左边的 main

The first main in right side of = is the first argument of current function run(main=None, argv=None). While sys.modules['__main__'] means current running file (e.g. my_model.py). 有以下两种情况:
(1) 如果 my_model.py 中没有 main(_) 函数,则应该向 tf.app.run 中输入自定义的主函数 tf.app.run(main=my_main_running_function)
(2) 如果 my_model.py 中包含 main(_) 函数,则当参数 mainNone 时,_sys.modules['__main__'].main 获取 my_model.pymain(_) 函数。

sys.exit(main(sys.argv[:1] + flags_passthrough)) 调用 main 函数,参数为文件名 + 无法解析数据的列表。定义 main 函数时需要设置参数,def main(_): 是正确的,def main(): 是不正确的。

如果是从其它模块调用该模块程序,则不会运行 main(_) 函数。如果直接运行该模块程序,则会运行 main(_) 函数。如果此文件被其他文件 import 的时候,不会执行 main(_) 函数。

3 example

3.1 命令行执行

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

tf.app.flags.DEFINE_integer('train_batch_size', 12, 'The number of images in each batch during training.')
tf.app.flags.DEFINE_boolean("is_train", True, "")


def main(_):
    print("{}".format(FLAGS.train_batch_size))
    print("{}".format(FLAGS.is_train))
    print(_)


if __name__ == '__main__':
    tf.app.run()
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py -h
usage: test.py [-h] [--train_batch_size TRAIN_BATCH_SIZE]
               [--is_train [IS_TRAIN]] [--nois_train]

optional arguments:
  -h, --help            show this help message and exit
  --train_batch_size TRAIN_BATCH_SIZE
                        The number of images in each batch during training.
  --is_train [IS_TRAIN]
  --nois_train
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py
12
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64
64
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False
64
False
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['test.py', '--gpus', '0']
strong@foreverstrong:~/git_workspace/MonoGRNet$

3.2 PyCharm 执行

  • 参数设置

在这里插入图片描述

  • 调试过程

在这里插入图片描述

  • 调试过程

在这里插入图片描述

  • 调试过程

在这里插入图片描述

  • 运行结果
/usr/bin/python2.7 /home/strong/git_workspace/MonoGRNet/test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['/home/strong/git_workspace/MonoGRNet/test.py', '--gpus', '0']

Process finished with exit code 0
import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
05-26
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yongqiang Cheng

梦想不是浮躁,而是沉淀和积累。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值