题外话:最近读开源的tensorflow程序代码,发现了好用的flag,感觉有必要学习一下。
在程序之中我们通常都会用到很多参数变量、常量、或者标志量来辅助程序的流程。例如,在规定神经网络的输入中,我们需要规定输入的尺寸width、height等等;我们也可以通过定义is_training标志量来判断当前流程是训练阶段还是测试阶段。Tensorflow中的tf.app.flags可以很好地统一“管理”用户自定义的多个参数量,并且可以全局调用。
tf.app.flags的官方定义地址:https://www.tensorflow.org/api_docs/python/tf/app/flags 指向github的 https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/platform/flags.py,其中给出了flags的具体实现定义,代码我就不搬运了。
FLAGS的使用方法
下面这个示例代码可以很好地解释如何在程序中使用 tf.app.flags。
import tensorflow as tf
# 定义输入尺寸的参数 input_size 为整型,值为128
tf.app.flags.DEFINE_integer('input_size', 128, '')
# 定义数据的batch_size 为整型,值为32
tf.app.flags.DEFINE_integer('batch_size', 32, '')
# 定义一个浮点型参数,学习率learning_rate为0.001
tf.app.flags.DEFINE_float('learning_rate', 0.001, '')
# 定义循环迭代最大次数为1000
tf.app.flags.DEFINE_integer('max_iteration', 1000, '')
# 定义当前阶段是否为训练,is_training为bool型数据,在代码运行过程中通常会根据迭代次数变化
tf.app.flags.DEFINE_boolean('is_training', True, 'point out to train or test the model')
# 定义保存中间训练结果的迭代次数,为整型100,意味着模型每循环100次则本地保存对应的中间训练结果
tf.app.flags.DEFINE_integer('save_steps', 100, '')
# 定义一个字符串类型的变量,指向预训练模型存储的地址
tf.app.flags.DEFINE_string('pretrained_model_directory', None, 'point to the pretrained model')
# 声明变量FLAGS
FLAGS = tf.app.flags.FLAGS
# 即可通过调用变量名获取对应之前定义的参数值
print("learning rate: ", FLAGS.learning_rate)
需要留意的是定义的参数类型确定后,须调用对应的变量类型的定义API,如DEFINE_float等等。若数据类型和实际需求不对应,后续算法运行过程中可能会报错。