tf.app.run(main=None, argv=None)
tf.app.run(main=None, argv=None)
TensorFlow
https://tensorflow.google.cn/
TensorFlow -> API
https://tensorflow.google.cn/versions
TensorFlow 1.x -> r1.4
https://github.com/tensorflow/docs/tree/r1.4/site/en/api_docs
tf.app.run(main=None, argv=None)
执行程序中的 main(_)
函数,并解析命令行参数。可选参数是 main=None, argv=None
,argv
可理解为 tf.app.run(main=None, argv=None)
输入列表。
1. /usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic entry point script."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys as _sys
from tensorflow.python.platform import flags
from tensorflow.python.util.all_util import remove_undocumented
def _benchmark_tests_can_log_memory():
return True
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
_allowed_symbols = [
'run',
# Allowed submodule.
'flags',
]
remove_undocumented(__name__, _allowed_symbols)
flags_passthrough = f._parse_flags(args=args)
用来解析命令行参数的函数。
2. tf.app.run()
2.1 调用示例
if __name__ == '__main__':
tf.app.run()
该语句经常用于自己测试。
文件名定为 my_model.py
,__name__
有两种情况:
-
当该文件直接运行时,执行
python my_model.py
__name__
数值是__main__
,因此会执行if
内的语句。 -
当
my_model.py
文件以模块形式导入时,执行import my_model
__name__
数值是my_model
,因此不会执行if
内的语句
__name__
= 模块名,每个 Python
模块 (Python
文件) 都包含内置的变量 __name__
,Python
中的模块名可以分成两类:
- 第一类是
xxx.py
文件,模块名就是xxx
。例如yongqiang.py
,那么模块名是yongqiang
,调用方法就是import yongqiang
。 - 第二类是
__main__
,当直接运行xxx.py
文件时,缺省调用的模块名是__main__
。
该 if
语句用来判断,该模块是正在被 import
还是被 SHELL 单独运行,被 SHELL 单独运行时,为 True
,执行 tf.app.run()
。在该模块被 import
时,main(_)
为模块中的一个函数。
2.2 函数原型
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
Runs the program with an optional main
function and argv
list.
主函数中的 tf.app.run()
会调用 main(_)
,并传递参数,因此必须在 main(_)
函数中设置一个参数。如果要更换 main(_)
函数为 test(args)
,只需要在 tf.app.run()
中传入一个指定的函数名即可 tf.app.run(test)
。
如果你的代码中的入口函数不叫 main(_)
,而是一个其他名字的函数 test(args)
,则你应该这样写入口 tf.app.run(test)
。
2.3 代码解析
f = flags.FLAGS
flags = tf.app.flags
FLAGS = flags.FLAGS
tf.app.flags
用于接收命令行传递参数。
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
将传入的参数 argv
从第二个开始切片 copy
到 args
形成一个列表 (第一个为函数名),如果没有则传 args=None
。
FLAGS = tf.app.flags.FLAGS
语句存在,表示输入已经解析。tf.app.run()
中 argv=None
,通过 args = argv[1:] if argv else None
语句,可知 args=None
(即不指定,后面会自动解析 command)。
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
f = flags.FLAGS
构造了解析器 f
用以解析 args
,flags_passthrough = f._parse_flags(args=args)
解析 args
列表或者 command 输入。args
列表为空,则解析 command 输入,返回的 flags_passthrough
内为无法解析的数据列表 (不包括文件名)。
main = main or sys.modules['__main__'].main
默认执行参数中指定的 main
函数。若 main=None
,则默认程序中 main(_)
函数。如果 run(main=None, argv=None)
函数的 main
参数非 None
,则 main
为输入参数 main
。反之,读取代码中的 main()
函数为 main(_)
。
在没有传入主函数参数时,就认为当前模块中已经有一个叫 main(_)
的主函数,将 main(_)
赋给等号左边 main
。在传入主函数参数时,将传入的当前模块自己定义的主函数传给等号左边的 main
。
The first main
in right side of =
is the first argument of current function run(main=None, argv=None)
. While sys.modules['__main__']
means current running file (e.g. my_model.py
). 有以下两种情况:
(1) 如果 my_model.py
中没有 main(_)
函数,则应该向 tf.app.run
中输入自定义的主函数 tf.app.run(main=my_main_running_function)
。
(2) 如果 my_model.py
中包含 main(_)
函数,则当参数 main
为 None
时,_sys.modules['__main__'].main
获取 my_model.py
的 main(_)
函数。
sys.exit(main(sys.argv[:1] + flags_passthrough))
调用 main
函数,参数为文件名 + 无法解析数据的列表。定义 main
函数时需要设置参数,def main(_):
是正确的,def main():
是不正确的。
如果是从其它模块调用该模块程序,则不会运行 main(_)
函数。如果直接运行该模块程序,则会运行 main(_)
函数。如果此文件被其他文件 import
的时候,不会执行 main(_)
函数。
3 example
3.1 命令行执行
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
tf.app.flags.DEFINE_integer('train_batch_size', 12, 'The number of images in each batch during training.')
tf.app.flags.DEFINE_boolean("is_train", True, "")
def main(_):
print("{}".format(FLAGS.train_batch_size))
print("{}".format(FLAGS.is_train))
print(_)
if __name__ == '__main__':
tf.app.run()
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py -h
usage: test.py [-h] [--train_batch_size TRAIN_BATCH_SIZE]
[--is_train [IS_TRAIN]] [--nois_train]
optional arguments:
-h, --help show this help message and exit
--train_batch_size TRAIN_BATCH_SIZE
The number of images in each batch during training.
--is_train [IS_TRAIN]
--nois_train
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py
12
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64
64
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False
64
False
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['test.py', '--gpus', '0']
strong@foreverstrong:~/git_workspace/MonoGRNet$
3.2 PyCharm 执行
- 参数设置
- 调试过程
- 调试过程
- 调试过程
- 运行结果
/usr/bin/python2.7 /home/strong/git_workspace/MonoGRNet/test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['/home/strong/git_workspace/MonoGRNet/test.py', '--gpus', '0']
Process finished with exit code 0