Tensorflow | 常用 API 之 FLAGS —— tf.app.flags

 

题外话:最近读开源的tensorflow程序代码,发现了好用的flag,感觉有必要学习一下。

 

在程序之中我们通常都会用到很多参数变量、常量、或者标志量来辅助程序的流程。例如,在规定神经网络的输入中,我们需要规定输入的尺寸width、height等等;我们也可以通过定义is_training标志量来判断当前流程是训练阶段还是测试阶段。Tensorflow中的tf.app.flags可以很好地统一“管理”用户自定义的多个参数量,并且可以全局调用。

 

tf.app.flags的官方定义地址:https://www.tensorflow.org/api_docs/python/tf/app/flags 指向github的 https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/platform/flags.py,其中给出了flags的具体实现定义,代码我就不搬运了。

 

FLAGS的使用方法

下面这个示例代码可以很好地解释如何在程序中使用 tf.app.flags。

import tensorflow as tf

# 定义输入尺寸的参数 input_size 为整型,值为128
tf.app.flags.DEFINE_integer('input_size', 128, '')

# 定义数据的batch_size 为整型,值为32
tf.app.flags.DEFINE_integer('batch_size', 32, '')

# 定义一个浮点型参数,学习率learning_rate为0.001
tf.app.flags.DEFINE_float('learning_rate', 0.001, '')

# 定义循环迭代最大次数为1000
tf.app.flags.DEFINE_integer('max_iteration', 1000, '')

# 定义当前阶段是否为训练,is_training为bool型数据,在代码运行过程中通常会根据迭代次数变化
tf.app.flags.DEFINE_boolean('is_training', True, 'point out to train or test the model')

# 定义保存中间训练结果的迭代次数,为整型100,意味着模型每循环100次则本地保存对应的中间训练结果
tf.app.flags.DEFINE_integer('save_steps', 100, '')

# 定义一个字符串类型的变量,指向预训练模型存储的地址
tf.app.flags.DEFINE_string('pretrained_model_directory', None, 'point to the pretrained model')



# 声明变量FLAGS
FLAGS = tf.app.flags.FLAGS

# 即可通过调用变量名获取对应之前定义的参数值
print("learning rate: ", FLAGS.learning_rate)


需要留意的是定义的参数类型确定后,须调用对应的变量类型的定义API,如DEFINE_float等等。若数据类型和实际需求不对应,后续算法运行过程中可能会报错。

 

 

 

 

 

 

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值