TensorFlow中关于tf.app.flags命令行参数解析模块

tf.app.flags命令行参数解析模块

说道命令行参数解析,就不得不提到 python 的 argparse 模块,详情可参考我之前的一篇文章:python argparse 模块命令行参数用法及说明

在阅读相关工程的源码时,很容易发现 tf.app.flags 模块的身影。其作用与 python 的 argparse 类似。

直接上代码实例,新建一个名为 test_flags.py 的文件,内容如下:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

#coding:utf-8

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# tf.app.flags.DEFINE_string("param_name", "default_val", "description")

tf.app.flags.DEFINE_string("train_data_path", "/home/feige", "training data dir")

tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")

tf.app.flags.DEFINE_integer("train_batch_size", 128, "batch size of train data")

tf.app.flags.DEFINE_integer("test_batch_size", 64, "batch size of test data")

tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")

def main(unused_argv):

    train_data_path = FLAGS.train_data_path

    print("train_data_path", train_data_path)

    train_batch_size = FLAGS.train_batch_size

    print("train_batch_size", train_batch_size)

    test_batch_size = FLAGS.test_batch_size

    print("test_batch_size", test_batch_size)

    size_sum = tf.add(train_batch_size, test_batch_size)

    with tf.Session() as sess:

        sum_result = sess.run(size_sum)

        print("sum_result", sum_result)

# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数

if __name__ == '__main__':

    tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)

上述代码已给出较为详细的注释,在此不再赘述。

该文件的调用示例以及运行结果如下所示

如果需要修改默认参数的值,则在命令行传入自定义参数值即可,若全部使用默认参数值,则可直接在命令行运行该 python 文件。

读者可能会对 tf.app.run() 有些疑问,在上述注释中也有所解释,但要真正弄清楚其运行原理

还需查阅其源代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

def run(main=None, argv=None):

  """Runs the program with an optional 'main' function and 'argv' list."""

  f = flags.FLAGS

  # Extract the args from the optional `argv` list.

  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command

  # line otherwise.

  # pylint: disable=protected-access

  flags_passthrough = f._parse_flags(args=args)

  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments

  # to the final program.

  sys.exit(main(sys.argv[:1] + flags_passthrough))

flags_passthrough=f._parse_flags(args=args)这里的_parse_flags就是我们tf.app.flags源码中用来解析命令行参数的函数。

所以这一行就是解析参数的功能;

下面两行代码也就是 tf.app.run 的核心意思:执行程序中 main 函数,并解析命令行参数!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sinat_40572875

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值