TensorFlow中的小知识:tf.flags.DEFINE_xxx()

读别人家的代码的时候经常看到这个,结果两三天不看居然忘记了,这脑子绝对上锈了,决定记下来免得老是查来查去的。。。
内容包含如下几个我们经常看到的几个函数:
①tf.flags.DEFINE_xxx()
②FLAGS = tf.flags.FLAGS
③FLAGS._parse_flags()


简单的说:

用于帮助我们添加命令行的可选参数。
也就是说利用该函数我们可以实现在命令行中选择需要设定的参数来运行程序,
可以不用反复修改源代码中的参数,直接在命令行中进行参数的设定。

举个栗子:

程序train.py文件中的小部分代码如下所示:

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('name', 'default', 'name of the model')
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
tf.flags.DEFINE_float('train_keep_prob', 0.5, 'dropout rate during training')
tf.flags.DEFINE_string('input_file', '', 'utf8 encoded text file')
tf.flags.DEFINE_integer('max_steps', 100000, 'max steps to train')
tf.flags.DEFINE_integer('save_every_n', 1000, 'save the model every n steps')
tf.flags.DEFINE_integer('log_every_n', 10, 'log to the screen every n steps')
tf.flags.DEFINE_integer('max_vocab', 3500, 'max char number')
#全局参数设置,显示在命令行

在命令行中我们为了执行train.py文件,在命令行中输入:

python train.py \
  --input_file data/shakespeare.txt  \
  --name shakespeare \
  --num_steps 50 \
  --num_seqs 32 \
  --learning_rate 0.01 \
  --max_steps 20000

通过输入不同的文件名、参数,可以快速完成程序的调参和更换训练集的操作,不需要进入源码中更改。

备注:在此感谢上述代码的作者


实践操作一下:

现在我们有如下代码:

import tensorflow as tf
#取上述代码中一部分进行实验
tf.flags.DEFINE_integer('num_seqs', 100, 'number of seqs in one batch')
tf.flags.DEFINE_integer('num_steps', 100, 'length of one seq')
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')

#通过print()确定下面内容的功能
FLAGS = tf.flags.FLAGS #FLAGS保存命令行参数的数据
FLAGS._parse_flags() #将其解析成字典存储到FLAGS.__flags中
print(FLAGS.__flags)

print(FLAGS.num_seqs)

print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

尝试执行一下上述代码了解其各行代码的功能,可能因为tensorflow版本原因出现报错现象。
查看解决办法可点击链接

展开阅读全文

没有更多推荐了,返回首页