TensorFlow 中 tf.app.flags.FLAGS 的用法介绍

说道命令行参数解析,就不得不提到 python 的 argparse 模块,详情可参考博主之前的一篇博客:python argparse 模块命令行参数解析。在阅读相关工程的源码时,很容易发现 tf.app.flags 模块的身影。其作用与 python 的 argparse 类似。

一、tf.app.flags.FLAGS

例子1:

新建一个名为:app_flags.py 的文件。

#coding:utf-8
 
# 学习使用 tf.app.flags 使用,全局变量
# 可以再命令行中运行也是比较方便,如果只写 python app_flags.py 则代码运行时默认程序里面设置的默认设置
# 若 python app_flags.py --train_data_path <绝对路径 train.txt> --max_sentence_len 100
#    --embedding_size 100 --learning_rate 0.05  代码再执行的时候将会按照上面的参数来运行程序
 
import tensorflow as tf
 
FLAGS = tf.app.flags.FLAGS
 
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string("train_data_path", "/home/yongcai/chinese_fenci/train.txt", "training data dir")
tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
tf.app.flags.DEFINE_integer("max_sentence_len", 80, "max num of tokens per query")
tf.app.flags.DEFINE_integer("embedding_size", 50, "embedding size")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")
 
def main(unused_argv):
    #注意:这里的FLAGS.train_data_path已经是更新过的了
    train_data_path = FLAGS.train_data_path
    print("train_data_path", train_data_path)
    max_sentence_len = FLAGS.max_sentence_len
    print("max_sentence_len", max_sentence_len)
    embdeeing_size = FLAGS.embedding_size
    print("embedding_size", embdeeing_size)
    abc = tf.add(max_sentence_len, embdeeing_size)
 
    init = tf.global_variables_initializer()
 
    #with tf.Session() as sess:
        #sess.run(init)
        #print("abc", sess.run(abc))
 
    sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init)
    with sv.managed_session() as sess:
        print("abc:", sess.run(abc))
 
        # sv.saver.save(sess, "/home/yongcai/tmp/")
 
# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
    tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)

调用方法:

其中参数可以根据需求进行修改。

python app_flags.py --train_data_path <绝对路径 train.txt> 
     --max_sentence_len 100 --embedding_size 100 --learning_rate 0.05

如果这样调用:

python app_flags.py

则会执行程序时会自动调用程序中 default 中的参数。

注意:main(unused_argv)函数必须有参数,为无法解析出的sys.argv

如命令行输出:python app_flags.py --aa --max_sentence_len=60

则,unused_argv为['app_flags.py', '--aa']

例子2:

tf定义了tf.app.flags,用于支持接受命令行传递参数,相当于接受argv。

tf.app.run() 会自动调用main(_)函数,同时,将文件名和未解析出的内容传到main函数的参数中。

如第二次输出--aaaaaa 99未解析出,则放入main参数"_"中。

import tensorflow as tf
a=tf.get_default_graph()
import tensorflow as tf
 
#第一个是参数名称,第二个参数是默认值,第三个是参数描述
tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.app.flags.DEFINE_integer('int_name', 10,"descript2")
tf.app.flags.DEFINE_boolean('bool_name', False, "descript3")
 
FLAGS = tf.app.flags.FLAGS
 
#必须带参数,否则:'TypeError: main() takes no arguments (1 given)';   main的参数名随意定义,无要求
def main(_):  
    print(FLAGS.str_name)
    print(FLAGS.int_name)
    print(FLAGS.bool_name)
    print(_)
 
if __name__ == '__main__':
    tf.app.run()  #执行main函数

执行:

D:\Spyderprojects\test>python test3.py
def_v_1
10
False
['test3.py']
D:\Spyderprojects\test>python test3.py --int_name 100 --aaaaaaaa 10
def_v_1
100
False
['test3.py', '--aaaaaaaa', '10']

二、详解

1、使用tf.app.flags.FLAGS和tf.app.run()

tensorflow的程序中,在main函数下,都是使用tf.app.run()来启动

if __name__ == "__main__":  
  tf.app.run()  

查看源码可知,该函数是用来处理flag解析,然后执行main函数,那么flag解析是什么意思呢?诸如这样的:

可选参数是main=None, argv=None,argv是输入列表,有两种情况,

如上例,使用tf.app.run()上面已经有FLAGS = tf.app.flags.FLAGS了,则已经解析了输入。

则tf.app.run() 中argv=None,通过args = argv[1:] if argv else None则args=None(即不指定,后面会自动解析command)

f = flags.FLAGS构造了解析器f用以解析args, f._parse_flags(参数args)解析args列表或者command输入,args列表为空,则解析command输入,返回的flags_passthrough内为无法解析的数据列表(不包括文件名) 。

main = main or sys.modules['__main__'].main默认执行参数中指定的main函数,如main=None,则默认程序中main()函数

最后一步,main(sys.argv[:1] + flags_passthrough),调用main函数,参数为文件名+无法解析数据的列表

import sys
from tensorflow.python.platform import flags
 
def run(main=None, argv=None):
  """可选参数:‘main’和'argv'list。"""
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS             #构造了一个解析器f
 
  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None
  
  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access
 
  main = main or sys.modules['__main__'].main
 
  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

2、使用argparse和tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 

《tensorflow中的例子》

FLAGS=None先重置一下FLAGS列表

if __name__=="__main__":中首先使用parse=argparse.ArgumentParser()加载参数解析器,

使用FLAGS, unparsed=parse.parse_known_args()解析command输入,FLAGS为解析出的命名空间,unparsed为未解析出的输入列表。

调用tf.app.run(main=main, argv=[sys.argv[0]] + unparsed),指定执行函数为main,argv为文件名和未解析输入的列表

在tf.app.run中,使用flags又解析了一次。其中的args为刨除文件名的未解析输入列表

在之后的程序中,还是主要使用FLAGS,这段代码使用tf.app.run的作用仅仅是指定main主函数和使用flags再解析一次输入

import tensorflow as tf
#导入命令行解析模块
import argparse
import sys

FLAGS=None

def main(_):
    print(sys.argv[0])
    print(FLAGS.dataDir)   #\mnist\inputData

#用这种方式保证了,如果此文件被其他文件import的时候,不会执行main中的代码
if __name__=="__main__": 
    #创建对象
    parse=argparse.ArgumentParser()
    #增加命令行
    parse.add_argument('--dataDir',type=str,default='\\tmp\\tensorflow\\mnist\\inputData',
                    help='Directory for string input data')
    FLAGS, unparsed=parse.parse_known_args()
    #解析命令行参数,调用main函数 main(sys.argv)
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 

  • 3
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值