为方便调参以及更换目录等,常将参数放到命令行中,常见有两种方式
#利用argparse
import argparse
import sys
#创建解析器参数,可以添加参数,可以禁用 add_help:默认是True,可以设置False禁用
#parser=argparse.ArgumentParser(description="This is a example program ")
"""
在执行程序的时候,定位参数必选,可选参数可选。
add_argument()常用的参数:
dest:如果提供dest,例如dest="a",那么可以通过args.a访问该参数
default:设置参数的默认值
action:参数触发的动作
store:保存参数,默认
store_const:保存一个被定义为参数规格一部分的值(常量),而不是一个来自参数解析而来的值。
store\_ture/store\_false:保存相应的布尔值
append:将值保存在一个列表中。
append_const:将一个定义在参数规格中的值(常量)保存在一个列表中。
count:参数出现的次数
parser.add_argument("-v", "--verbosity", action="count", default=0, help="increase output verbosity")
version:打印程序版本信息
type:把从命令行输入的结果转成设置的类型
choice:允许的参数值
parser.add_argument("-v", "--verbosity", type=int, choices=\[0, 1, 2\], help="increase output verbosity")
help:参数命令的介绍
"""
parser = argparse.ArgumentParser()
parser.add_argument('--fake_data', nargs='?', const=True, type=bool,
default=False,
help='If true, uses fake data for unit testing.')
parser.add_argument('--max_steps', type=int, default=1000,
help='Number of steps to run trainer.')
parser.add_argument('--learning_rate', type=float, default=0.001,
help='Initial learning rate')
parser.add_argument('--dropout', type=float, default=0.9,
help='Keep probability for training dropout.')
parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist\_with\_summaries',
help='Summaries log directory')
FLAGS, unparsed = parser.parse\_known\_args()
print(FLAGS)
#Namespace(data\_dir='/tmp/tensorflow/mnist/input\_data', dropout=0.9, fake\_data=False, learning\_rate=0.001, log\_dir='/tmp/tensorflow/mnist/logs/mnist\_with\_summaries', max\_steps=1000)
print(unparsed) #\[\]
#tf.app.run(main=main, argv=\[sys.argv\[0\]\] + unparsed)
args = parser.parse_args()
print(args.log_dir) #/tmp/tensorflow/mnist/logs/mnist\_with\_summaries
\# python3 test.py --dropout 0.1
#===========================================================
#利用tf.app.flags组件
import tensorflow as tf
import logging
\# Define hyperparameters
#定义一个tf.app.flags对象
flags = tf.app.flags
FLAGS = flags.FLAGS
#调用自带的设置不同的类型
flags.DEFINE_boolean("enable\_colored\_log", False, "Enable colored log")
flags.DEFINE_string("validate\_tfrecords\_file",
"./data/a8a/a8a_test.libsvm.tfrecords",
"The glob pattern of validate TFRecords files")
flags.DEFINE_integer("label_size", 2, "Number of label size")
flags.DEFINE_float("learning_rate", 0.01, "The learning rate")
def main():
\# Get hyperparameters
if FLAGS.enable\_colored\_log:
import coloredlogs
coloredlogs.install()
logging.basicConfig(level=logging.INFO)
LABEL_SIZE = FLAGS.label_size
print(LABEL_SIZE)
return 0
if \_\_name\_\_ == "\_\_main\_\_":
main()
"""
用tf运行用以上两个都可,用spyder则只能用第一种
tf.app.run()函数源码如下:
from \_\_future\_\_ import absolute_import
from \_\_future\_\_ import division
from \_\_future\_\_ import print_function
import sys
from tensorflow.python.platform import flags
def run(main=None):
f = flags.FLAGS
f.\_parse\_flags()
main = main or sys.modules\['\_\_main\_\_'\].main
sys.exit(main(sys.argv))
"""
转载于:https://my.oschina.net/u/3726752/blog/3089271