Python 命令行参数定义和 Tensorflow 命令行参数定义
1. Python 命令行参数定义
Python 中通常使用标准库 argparse 来定义和解析命令行参数, 如:
import argparse
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as the
            original no-argument call did; passing a list is handy for tests.

    Returns:
        An ``argparse.Namespace`` with integer attributes ``batchsize``
        (option ``-b``, default 1) and ``gpuid`` (option ``-g``, default 0).
    """
    parser = argparse.ArgumentParser()
    # type=int converts values given on the command line to int so they match
    # the type of the defaults; without it "-b 4" would produce the str "4".
    parser.add_argument('-b', dest='batchsize', type=int, default=1,
                        help='batch size (default: 1)')
    parser.add_argument('-g', dest='gpuid', type=int, default=0,
                        help='GPU id (default: 0)')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Parse the command line and echo each option value on its own line.
    cli_args = parse_args()
    for value in (cli_args.batchsize, cli_args.gpuid):
        print(value)
2. Tensorflow 命令行参数定义
TensorFlow 采用 tf.app.flags 来进行命令行参数传递 (要求 TensorFlow 的版本在 1.8.0 以上; 注意 TensorFlow 2.x 中已移除 tf.app, 需改用 tf.compat.v1.app.flags 或 absl.flags).
如 flags_test.py:
import tensorflow as tf
# Obtain the flags module. tf.app was removed in TensorFlow 2.x, so fall back
# to the TF1-compatibility alias there. On TF1 the original tf.app.flags is
# still used, so existing behavior is unchanged.
try:
    flags = tf.app.flags
except AttributeError:
    flags = tf.compat.v1.app.flags
# Global registry holding the parsed value of every defined flag.
FLAGS = flags.FLAGS
# Training-parameter flags. Every call is positional:
# (name, default, [allowed values,] help text).
flags.DEFINE_enum(
    'learning_policy', 'poly', ['poly', 'step'],
    'Learning rate policy for training.')
flags.DEFINE_float(
    'base_learning_rate', 1e-4,
    'The base learning rate for model training.')
flags.DEFINE_integer(
    'learning_rate_decay_step', 2000,
    'Decay the base learning rate at a fixed step.')
flags.DEFINE_integer(
    'train_batch_size', 12,
    'The number of images in each batch during training.')
# Multi-valued flag: give it once per value on the command line,
# e.g. --train_crop_size=513 --train_crop_size=513.
flags.DEFINE_multi_integer(
    'train_crop_size', [513] * 2,
    'Image crop size [height, width] during training.')
flags.DEFINE_boolean(
    'upsample_logits', True,
    'Upsample logits during training.')
flags.DEFINE_string(
    'dataset', 'dataset_name',
    'Name of the test dataset.')
def main(_):
    """Print the parsed value of each training flag, one per line.

    Args:
        _: Argument list handed over by the app runner; unused here.
    """
    flag_names = (
        'learning_policy',
        'base_learning_rate',
        'learning_rate_decay_step',
        'train_batch_size',
        'train_crop_size',
        'upsample_logits',
        'dataset',
    )
    for flag_name in flag_names:
        print(getattr(FLAGS, flag_name))
if __name__ == '__main__':
    # tf.app.run parses the flags and then calls main(argv). tf.app no longer
    # exists in TensorFlow 2.x, so use the TF1-compatibility alias there.
    try:
        run = tf.app.run
    except AttributeError:
        run = tf.compat.v1.app.run
    run()
python flags_test.py
转自: http://www.aiuai.cn/aifarm258.html