tensorflow命令行参数源代码分析

原帖:原帖

一、认识tensorflow命令行参数

在深度学习训练中,我们常常需要动态的配置诸如batch size、learning rate、epoch、kernel size等等超参数,同时在分布式训练时为了区别运行不同的代码,我们也需要配置一个参数用以运行不同代码。那么有无一种比较合适的可以动态配置的方法呢?答案是肯定的,一种是使用python的argparse库,另外一种是使用tensorflow的tf.app.flags组件,今天我们要讲的是后者。

tf.app.flags其实是tensorflow定义的一个类,本质上是基于argparse再封装更友好直观的一个类库,其主要源代码位于“tensorflow/tensorflow/python/platform/flags.py”,flags.py定义了如何构造和解析命令行参数,通过分析源代码及其注解,我们可以很直观的明了其原理,其源代码如下:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implementation of the flags interface."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse as _argparse

from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.all_util import remove_undocumented

_global_parser = _argparse.ArgumentParser()


# pylint: disable=invalid-name


class _FlagValues(object):
  """Global container and accessor for flags and their values."""

  def __init__(self):
    self.__dict__['__flags'] = {}
    self.__dict__['__parsed'] = False
    self.__dict__['__required_flags'] = set()

  def _parse_flags(self, args=None):
    result, unparsed = _global_parser.parse_known_args(args=args)
    for flag_name, val in vars(result).items():
      self.__dict__['__flags'][flag_name] = val
    self.__dict__['__parsed'] = True
    self._assert_all_required()
    return unparsed

  def __getattr__(self, name):
    """Retrieves the 'value' attribute of the flag --name."""
    try:
      parsed = self.__dict__['__parsed']
    except KeyError:
      # May happen during pickle.load or copy.copy
      raise AttributeError(name)
    if not parsed:
      self._parse_flags()
    if name not in self.__dict__['__flags']:
      raise AttributeError(name)
    return self.__dict__['__flags'][name]

  def __setattr__(self, name, value):
    """Sets the 'value' attribute of the flag --name."""
    if not self.__dict__['__parsed']:
      self._parse_flags()
    self.__dict__['__flags'][name] = value
    self._assert_required(name)

  def _add_required_flag(self, item):
    self.__dict__['__required_flags'].add(item)

  def _assert_required(self, flag_name):
    if (flag_name not in self.__dict__['__flags'] or
        self.__dict__['__flags'][flag_name] is None):
      raise AttributeError('Flag --%s must be specified.' % flag_name)

  def _assert_all_required(self):
    for flag_name in self.__dict__['__required_flags']:
      self._assert_required(flag_name)


def _define_helper(flag_name, default_value, docstring, flagtype):
  """Registers 'flag_name' with 'default_value' and 'docstring'."""
  _global_parser.add_argument('--' + flag_name,
                              default=default_value,
                              help=docstring,
                              type=flagtype)


# Provides the global object that can be used to access flags.
FLAGS = _FlagValues()


def DEFINE_string(flag_name, default_value, docstring):
  """Defines a flag of type 'string'.

  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as a string.
    docstring: A helpful message explaining the use of the flag.
  """
  _define_helper(flag_name, default_value, docstring, str)


def DEFINE_integer(flag_name, default_value, docstring):
  """Defines a flag of type 'int'.

  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as an int.
    docstring: A helpful message explaining the use of the flag.
  """
  _define_helper(flag_name, default_value, docstring, int)


def DEFINE_boolean(flag_name, default_value, docstring):
  """Defines a flag of type 'boolean'.

  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as a boolean.
    docstring: A helpful message explaining the use of the flag.
  """
  # Register a custom function for 'bool' so --flag=True works.
  def str2bool(v):
    return v.lower() in ('true', 't', '1')
  _global_parser.add_argument('--' + flag_name,
                              nargs='?',
                              const=True,
                              help=docstring,
                              default=default_value,
                              type=str2bool)

  # Add negated version, stay consistent with argparse with regard to
  # dashes in flag names.
  _global_parser.add_argument('--no' + flag_name,
                              action='store_false',
                              dest=flag_name.replace('-', '_'))


# The internal google library defines the following alias, so we match
# the API for consistency.
DEFINE_bool = DEFINE_boolean  # pylint: disable=invalid-name


def DEFINE_float(flag_name, default_value, docstring):
  """Defines a flag of type 'float'.

  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as a float.
    docstring: A helpful message explaining the use of the flag.
  """
  _define_helper(flag_name, default_value, docstring, float)


def mark_flag_as_required(flag_name):
  """Ensures that flag is not None during program execution.
  
  It is recommended to call this method like this:
  
    if __name__ == '__main__':
      tf.flags.mark_flag_as_required('your_flag_name')
      tf.app.run()
  
  Args:
    flag_name: string, name of the flag to mark as required.
 
  Raises:
    AttributeError: if flag_name is not registered as a valid flag name.
      NOTE: The exception raised will change in the future. 
  """
  if _global_parser.get_default(flag_name) is not None:
    _logging.warn(
        'Flag %s has a non-None default value; therefore, '
        'mark_flag_as_required will pass even if flag is not specified in the '
        'command line!' % flag_name)
  FLAGS._add_required_flag(flag_name)


def mark_flags_as_required(flag_names):
  """Ensures that flags are not None during program execution.
  
  Recommended usage:
  
    if __name__ == '__main__':
      tf.flags.mark_flags_as_required(['flag1', 'flag2', 'flag3'])
      tf.app.run()
  
  Args:
    flag_names: a list/tuple of flag names to mark as required.

  Raises:
    AttributeError: If any of flag name has not already been defined as a flag.
      NOTE: The exception raised will change in the future.
  """
  for flag_name in flag_names:
    mark_flag_as_required(flag_name)


_allowed_symbols = [
    # We rely on gflags documentation.
    'DEFINE_bool',
    'DEFINE_boolean',
    'DEFINE_float',
    'DEFINE_integer',
    'DEFINE_string',
    'FLAGS',
    'mark_flag_as_required',
    'mark_flags_as_required',
]
remove_undocumented(__name__, _allowed_symbols)

二、tensorflow命令行参数源代码分析

该源文件定义了很多保护成员函数(单下划线)和私有成员函数(双下划线),其只允许我们引用如下符号:

_allowed_symbols = [
    # We rely on gflags documentation.
    'DEFINE_bool',
    'DEFINE_boolean',
    'DEFINE_float',
    'DEFINE_integer',
    'DEFINE_string',
    'FLAGS',
    'mark_flag_as_required',
    'mark_flags_as_required',
]
remove_undocumented(__name__, _allowed_symbols)

def DEFINE_string(flag_name, default_value, docstring):

def DEFINE_integer(flag_name, default_value, docstring):

def DEFINE_boolean(flag_name, default_value, docstring):

def DEFINE_float(flag_name, default_value, docstring):

这些函数形参一致,均可以定义不同类型的命令行参数,有字符串型、整型、布尔型及浮点型,flag_name是该参数的名字,default_value是未在命令行显示配置时的默认值,docstring则是help文档,在使用“–flag_name”后接“–help”可以获取到该help说明文档,示例如下。DEFINE_bool其实是DEFINE_boolean的另一个别名,所以不管定义哪个他们都是等价的,如果设定的参数是诸如(‘true’, ‘t’, ‘1’)字符串,那么该函数会自动将这类参数转化为布尔值。

FLAGS:
实际上他是一个类名,我们看到源代码有如下定义

# Provides the global object that can be used to access flags.
FLAGS = _FlagValues()

通过这个类别名,我们可以获取到该类定义的属性和成员,第三章我们再来好好的说一下它是怎么使用的。

def mark_flag_as_required(flag_name):

def mark_flags_as_required(flag_names):

这两个函数的作用是一样的,都是强制要求显示定义命令行参数,否则会报错!前者可以强制要求显示定义一个命令行参数,而后者则可以强制要求显示定义N个命令行参数,它是通过调用前者来实现的。如果你用这两个接口来检查命令行参数,但你code里却没有定义,那么运行时就一定会报错!

三、如何构造tensorflow命令行参数

构造tensorflow命令行参数还是蛮简单的,通常按照如下步骤即可(以下默认都执行了“import tensorflow as tf”):

第一步(获取可以class _FlagValues(object)成员函数的使用句柄):

flags = tf.app.flags

第二步(不是必须的,只是为了更安全的检查):

flags.mark_flags_as_required(flag_names)
或者
flags.mark_flag_as_required(flag_name)

第三步(定义不同类型的命令行参数):

flags.DEFINE_string(flag_name, default_value, docstring)
flags.DEFINE_integer(flag_name, default_value, docstring)
flags.DEFINE_boolean(flag_name, default_value, docstring)
flags.DEFINE_float(flag_name, default_value, docstring)

至此,tensorflow的命令行参数就构造完成了!

四、如何解析tensorflow命令行参数

解析tensorflow命令行参数更加简单,但是命令行参数名字一定要注意与构造时的一致,否则就会报错没有意义了,通常按照如下步骤解析

第一步(获取可以使用class _FlagValues(object)成员属性的使用句柄):

FLAGS = flags.FLAGS

第二步(直接获取已定义了的命令行参数,不同参数类型均可以按照如下方式获取):

flag_name = FLAGS.flag_name

至此,tensorflow的命令行参数就解析完成了!

五、举例说明tensorflow命令行参数

如下是一个tensorflow命令行参数构造和解析的使用例子,其命令行参数在这里没有实际意义,我们只是为了举例说明,以便观察其运行结果并加深对tensorflow命令行参数原理的理解!

import tensorflow as tf

flags = tf.app.flags # structure first

flags.mark_flags_as_required(['batch_size', 'learning_rate', 'data_dir']) # structure second

# structure third
#flags.DEFINE_integer("batch_size", 1000, "training batch size")
flags.DEFINE_float("learning_rate", 0.1, "training learning rate")
flags.DEFINE_string("data_dir", "/home/xsr-ai/study", "training data directory")
flags.DEFINE_boolean("use_gpu_device", "False", "use gpu device to training or not")

# parsing first
FLAGS = flags.FLAGS

# parsing  second
batch_size = FLAGS.batch_size
learning_rate = FLAGS.learning_rate
data_dir = FLAGS.data_dir
use_gpu_device = FLAGS.use_gpu_device

# finally validation
print("batch_size=%s" % batch_size)
print("learning_rate=%f" % learning_rate)
print("data_dir=%s" % data_dir)
print("use_gpu_device=%d" % use_gpu_device)

会出错
#flags.DEFINE_integer(“batch_size”, 1000, “training batch size”)
注释去掉即可
mark_flag_as_required,他会进行安全检查确保被他标志了的命令行参数都有显示定义!

成功运行

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值