说道命令行参数解析,就不得不提到 python 的 argparse 模块,详情可参考博主之前的一篇博客:python argparse 模块命令行参数解析。在阅读相关工程的源码时,很容易发现 tf.app.flags 模块的身影。其作用与 python 的 argparse 类似。
一、tf.app.flags.FLAGS
例子1:
新建一个名为:app_flags.py 的文件。
#coding:utf-8
# 学习使用 tf.app.flags 使用,全局变量
# 可以再命令行中运行也是比较方便,如果只写 python app_flags.py 则代码运行时默认程序里面设置的默认设置
# 若 python app_flags.py --train_data_path <绝对路径 train.txt> --max_sentence_len 100
# --embedding_size 100 --learning_rate 0.05 代码再执行的时候将会按照上面的参数来运行程序
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/yongcai/chinese_fenci/train.txt", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("max_sentence_len", 80, "max num of tokens per query")
tf.app.flags.DEFINE_integer("embedding_size", 50, "embedding size")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")
def main(unused_argv):
#注意:这里的FLAGS.train_data_path已经是更新过的了
train_data_path = FLAGS.train_data_path
print("train_data_path", train_data_path)
max_sentence_len = FLAGS.max_sentence_len
print("max_sentence_len", max_sentence_len)
embdeeing_size = FLAGS.embedding_size
print("embedding_size", embdeeing_size)
abc = tf.add(max_sentence_len, embdeeing_size)
init = tf.global_variables_initializer()
#with tf.Session() as sess:
#sess.run(init)
#print("abc", sess.run(abc))
sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init)
with sv.managed_session() as sess:
print("abc:", sess.run(abc))
# sv.saver.save(sess, "/home/yongcai/tmp/")
# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
tf.app.run() # 解析命令行参数,调用main 函数 main(sys.argv)
调用方法:
其中参数可以根据需求进行修改。
python app_flags.py --train_data_path <绝对路径 train.txt>
--max_sentence_len 100 --embedding_size 100 --learning_rate 0.05
如果这样调用:
python app_flags.py
则会执行程序时会自动调用程序中 default 中的参数。
注意:main(unused_argv)函数必须有参数,为无法解析出的sys.argv
如命令行输出:python app_flags.py --aa --max_sentence_len=60
则,unused_argv为['app_flags.py', '--aa']
例子2:
tf定义了tf.app.flags,用于支持接受命令行传递参数,相当于接受argv。
tf.app.run() 会自动调用main(_)函数,同时,将文件名和未解析出的内容传到main函数的参数中。
如第二次输出--aaaaaa 99未解析出,则放入main参数"_"中。
import tensorflow as tf
a=tf.get_default_graph()
import tensorflow as tf
#第一个是参数名称,第二个参数是默认值,第三个是参数描述
tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.app.flags.DEFINE_integer('int_name', 10,"descript2")
tf.app.flags.DEFINE_boolean('bool_name', False, "descript3")
FLAGS = tf.app.flags.FLAGS
#必须带参数,否则:'TypeError: main() takes no arguments (1 given)'; main的参数名随意定义,无要求
def main(_):
print(FLAGS.str_name)
print(FLAGS.int_name)
print(FLAGS.bool_name)
print(_)
if __name__ == '__main__':
tf.app.run() #执行main函数
执行:
D:\Spyderprojects\test>python test3.py
def_v_1
10
False
['test3.py']
D:\Spyderprojects\test>python test3.py --int_name 100 --aaaaaaaa 10
def_v_1
100
False
['test3.py', '--aaaaaaaa', '10']
二、详解
1、使用tf.app.flags.FLAGS和tf.app.run()
tensorflow的程序中,在main函数下,都是使用tf.app.run()来启动
if __name__ == "__main__":
tf.app.run()
查看源码可知,该函数是用来处理flag解析,然后执行main函数,那么flag解析是什么意思呢?诸如这样的:
可选参数是main=None, argv=None,argv是输入列表,有两种情况,
如上例,使用tf.app.run() ,上面已经有FLAGS = tf.app.flags.FLAGS了,则已经解析了输入。
则tf.app.run() 中argv=None,通过args = argv[1:] if argv else None则args=None(即不指定,后面会自动解析command)
f = flags.FLAGS构造了解析器f用以解析args, f._parse_flags(参数args)解析args列表或者command输入,args列表为空,则解析command输入,返回的flags_passthrough内为无法解析的数据列表(不包括文件名) 。
main = main or sys.modules['__main__'].main默认执行参数中指定的main函数,如main=None,则默认程序中main()函数
最后一步,main(sys.argv[:1] + flags_passthrough),调用main函数,参数为文件名+无法解析数据的列表
import sys
from tensorflow.python.platform import flags
def run(main=None, argv=None):
"""可选参数:‘main’和'argv'list。"""
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS #构造了一个解析器f
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
2、使用argparse和tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
《tensorflow中的例子》
FLAGS=None先重置一下FLAGS列表
if __name__=="__main__":中首先使用parse=argparse.ArgumentParser()加载参数解析器,
使用FLAGS, unparsed=parse.parse_known_args()解析command输入,FLAGS为解析出的命名空间,unparsed为未解析出的输入列表。
调用tf.app.run(main=main, argv=[sys.argv[0]] + unparsed),指定执行函数为main,argv为文件名和未解析输入的列表
在tf.app.run中,使用flags又解析了一次。其中的args为刨除文件名的未解析输入列表
在之后的程序中,还是主要使用FLAGS,这段代码使用tf.app.run的作用仅仅是指定main主函数和使用flags再解析一次输入
import tensorflow as tf
#导入命令行解析模块
import argparse
import sys
FLAGS=None
def main(_):
print(sys.argv[0])
print(FLAGS.dataDir) #\mnist\inputData
#用这种方式保证了,如果此文件被其他文件import的时候,不会执行main中的代码
if __name__=="__main__":
#创建对象
parse=argparse.ArgumentParser()
#增加命令行
parse.add_argument('--dataDir',type=str,default='\\tmp\\tensorflow\\mnist\\inputData',
help='Directory for string input data')
FLAGS, unparsed=parse.parse_known_args()
#解析命令行参数,调用main函数 main(sys.argv)
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)