Reposted from the original post.
1. Getting to know TensorFlow command-line flags
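In deep-learning training we often need to set hyperparameters such as batch size, learning rate, number of epochs and kernel size at launch time. In distributed training we also need a parameter that tells each process which code path to run. Is there a convenient way to configure such values dynamically? Yes, and there are two common options. One is Python's argparse library. The other is TensorFlow's tf.app.flags component, which is the subject of this article.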
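tf.app.flags is a module provided by TensorFlow, not a class. In the version discussed here it is a thin wrapper around argparse that offers a friendlier, more direct interface. Its main source file is "tensorflow/tensorflow/python/platform/flags.py", which defines how command-line flags are built and parsed. Reading the source and its comments makes the mechanism easy to follow.

The listing below is the argparse-based implementation from early TensorFlow 1.x. Later 1.x releases rebuilt tf.flags on top of absl-py, and TensorFlow 2.x keeps it only under tf.compat.v1. The source is as follows: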
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the flags interface."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse as _argparse
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.all_util import remove_undocumented
_global_parser = _argparse.ArgumentParser()
# pylint: disable=invalid-name
class _FlagValues(object):
  """Global container and accessor for flags and their values."""
  def __init__(self):
    self.__dict__['__flags'] = {}
    self.__dict__['__parsed'] = False
    self.__dict__['__required_flags'] = set()
  def _parse_flags(self, args=None):
    result, unparsed = _global_parser.parse_known_args(args=args)
    for flag_name, val in vars(result).items():
      self.__dict__['__flags'][flag_name] = val
    self.__dict__['__parsed'] = True
    self._assert_all_required()
    return unparsed
  def __getattr__(self, name):
    """Retrieves the 'value' attribute of the flag --name."""
    try:
      parsed = self.__dict__['__parsed']
    except KeyError:
      # May happen during pickle.load or copy.copy
      raise AttributeError(name)
    if not parsed:
      self._parse_flags()
    if name not in self.__dict__['__flags']:
      raise AttributeError(name)
    return self.__dict__['__flags'][name]
  def __setattr__(self, name, value):
    """Sets the 'value' attribute of the flag --name."""
    if not self.__dict__['__parsed']:
      self._parse_flags()
    self.__dict__['__flags'][name] = value
    self._assert_required(name)
  def _add_required_flag(self, item):
    self.__dict__['__required_flags'].add(item)
  def _assert_required(self, flag_name):
    if (flag_name not in self.__dict__['__flags'] or
        self.__dict__['__flags'][flag_name] is None):
      raise AttributeError('Flag --%s must be specified.' % flag_name)
  def _assert_all_required(self):
    for flag_name in self.__dict__['__required_flags']:
      self._assert_required(flag_name)
def _define_helper(flag_name, default_value, docstring, flagtype):
  """Registers 'flag_name' with 'default_value' and 'docstring'."""
  _global_parser.add_argument('--' + flag_name,
                              default=default_value,
                              help=docstring,
                              type=flagtype)
# Provides the global object that can be used to access flags.
FLAGS = _FlagValues()
def DEFINE_string(flag_name, default_value, docstring):
  """Defines a flag of type 'string'.
  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as a string.
    docstring: A helpful message explaining the use of the flag.
  """
  _define_helper(flag_name, default_value, docstring, str)
def DEFINE_integer(flag_name, default_value, docstring):
  """Defines a flag of type 'int'.
  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as an int.
    docstring: A helpful message explaining the use of the flag.
  """
  _define_helper(flag_name, default_value, docstring, int)
def DEFINE_boolean(flag_name, default_value, docstring):
  """Defines a flag of type 'boolean'.
  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as a boolean.
    docstring: A helpful message explaining the use of the flag.
  """
  # Register a custom function for 'bool' so --flag=True works.
  def str2bool(v):
    return v.lower() in ('true', 't', '1')
  _global_parser.add_argument('--' + flag_name,
                              nargs='?',
                              const=True,
                              help=docstring,
                              default=default_value,
                              type=str2bool)
  # Add negated version, stay consistent with argparse with regard to
  # dashes in flag names.
  _global_parser.add_argument('--no' + flag_name,
                              action='store_false',
                              dest=flag_name.replace('-', '_'))
# The internal google library defines the following alias, so we match
# the API for consistency.
DEFINE_bool = DEFINE_boolean  # pylint: disable=invalid-name
def DEFINE_float(flag_name, default_value, docstring):
  """Defines a flag of type 'float'.
  Args:
    flag_name: The name of the flag as a string.
    default_value: The default value the flag should take as a float.
    docstring: A helpful message explaining the use of the flag.
  """
  _define_helper(flag_name, default_value, docstring, float)
def mark_flag_as_required(flag_name):
  """Ensures that flag is not None during program execution.
  It is recommended to call this method like this:
    if __name__ == '__main__':
      tf.flags.mark_flag_as_required('your_flag_name')
      tf.app.run()
  Args:
    flag_name: string, name of the flag to mark as required.
  Raises:
    AttributeError: if flag_name is not registered as a valid flag name.
    NOTE: The exception raised will change in the future.
  """
  if _global_parser.get_default(flag_name) is not None:
    _logging.warn(
        'Flag %s has a non-None default value; therefore, '
        'mark_flag_as_required will pass even if flag is not specified in the '
        'command line!' % flag_name)
  FLAGS._add_required_flag(flag_name)
def mark_flags_as_required(flag_names):
  """Ensures that flags are not None during program execution.
  Recommended usage:
    if __name__ == '__main__':
      tf.flags.mark_flags_as_required(['flag1', 'flag2', 'flag3'])
      tf.app.run()
  Args:
    flag_names: a list/tuple of flag names to mark as required.
  Raises:
    AttributeError: If any of flag name has not already been defined as a flag.
    NOTE: The exception raised will change in the future.
  """
  for flag_name in flag_names:
    mark_flag_as_required(flag_name)
_allowed_symbols = [
# We rely on gflags documentation.
'DEFINE_bool',
'DEFINE_boolean',
'DEFINE_float',
'DEFINE_integer',
'DEFINE_string',
'FLAGS',
'mark_flag_as_required',
'mark_flags_as_required',
]
remove_undocumented(__name__, _allowed_symbols)
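2. Analyzing the TensorFlow command-line flag source
The file defines a number of internal names marked with a leading underscore, such as _FlagValues, _define_helper and _global_parser. Through remove_undocumented it exports only the following symbols: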
_allowed_symbols = [
# We rely on gflags documentation.
'DEFINE_bool',
'DEFINE_boolean',
'DEFINE_float',
'DEFINE_integer',
'DEFINE_string',
'FLAGS',
'mark_flag_as_required',
'mark_flags_as_required',
]
remove_undocumented(__name__, _allowed_symbols)
def DEFINE_string(flag_name, default_value, docstring):
def DEFINE_integer(flag_name, default_value, docstring):
def DEFINE_boolean(flag_name, default_value, docstring):
def DEFINE_float(flag_name, default_value, docstring):
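These functions share the same parameters and define flags of type string, integer, boolean and float respectively.

- flag_name is the flag's name.
- default_value is the value used when the flag is not given explicitly on the command line.
- docstring is the flag's help text. It is printed, together with the help of every other flag, when the script is run with --help.

DEFINE_bool is simply an alias of DEFINE_boolean, so the two are interchangeable. A boolean flag accepts several forms. Passing --flag_name alone sets it to True, and --noflag_name sets it to False. With --flag_name=value, the value is converted by str2bool, which yields True only for the strings 'true', 't' or '1' (case-insensitive) and False for anything else.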
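The flag definitions above boil down to plain argparse calls. The following minimal sketch re-creates, with only the standard library, what _define_helper (used by DEFINE_integer) and DEFINE_boolean register in the listing. build_parser and the two flags are illustrative names, and this is a re-creation of the listing's logic, not the TensorFlow module itself.

```python
import argparse


def str2bool(v):
    # Same rule as the listing: only 'true', 't' or '1' (any case) mean True.
    return v.lower() in ('true', 't', '1')


def build_parser():
    parser = argparse.ArgumentParser()
    # What DEFINE_integer("batch_size", 1000, ...) registers via _define_helper.
    parser.add_argument('--batch_size', default=1000,
                        help='training batch size', type=int)
    # What DEFINE_boolean("use_gpu_device", False, ...) registers:
    # '--use_gpu_device' alone yields const=True,
    # '--use_gpu_device=x' is converted by str2bool.
    parser.add_argument('--use_gpu_device', nargs='?', const=True,
                        default=False, help='use gpu device or not',
                        type=str2bool)
    # The negated form writes False into the same destination.
    parser.add_argument('--nouse_gpu_device', action='store_false',
                        dest='use_gpu_device')
    return parser


parser = build_parser()
print(parser.parse_args(['--batch_size=64']).batch_size)           # 64 (an int)
print(parser.parse_args(['--use_gpu_device']).use_gpu_device)      # True
print(parser.parse_args(['--use_gpu_device=T']).use_gpu_device)    # True
print(parser.parse_args(['--use_gpu_device=yes']).use_gpu_device)  # False
print(parser.parse_args(['--use_gpu_device',
                         '--nouse_gpu_device']).use_gpu_device)    # False
```

Because str2bool never rejects a value, a typo such as --use_gpu_device=ture silently becomes False rather than raising an error.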
FLAGS:
It is not a class name but a global instance of the _FlagValues class, as the source defines:
# Provides the global object that can be used to access flags.
FLAGS = _FlagValues()
Through this object we read the flags' values as attributes. Its usage is covered in the following sections.
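The key design point of _FlagValues is lazy parsing. FLAGS is created at import time, before any flag is defined, and the command line is parsed only on the first attribute access. The following cut-down sketch re-creates that behavior using only the standard library, under a few assumptions:

- LazyFlags and _parser are hypothetical stand-ins for _FlagValues and _global_parser.
- The required-flag checks are left out.
- The argv argument is added so the example does not depend on the real command line.

```python
import argparse

_parser = argparse.ArgumentParser()  # hypothetical stand-in for _global_parser


class LazyFlags(object):
    """Cut-down re-creation of _FlagValues (no required-flag checks)."""

    def __init__(self, argv=None):
        # Write through __dict__ so that __setattr__ below is not triggered.
        self.__dict__['_flags'] = {}
        self.__dict__['_parsed'] = False
        self.__dict__['_argv'] = argv  # None means sys.argv[1:], as in argparse

    def _parse(self):
        known, unparsed = _parser.parse_known_args(args=self.__dict__['_argv'])
        self.__dict__['_flags'].update(vars(known))
        self.__dict__['_parsed'] = True
        return unparsed

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for flag names.
        if not self.__dict__['_parsed']:
            self._parse()
        if name not in self.__dict__['_flags']:
            raise AttributeError(name)
        return self.__dict__['_flags'][name]

    def __setattr__(self, name, value):
        # Assigning a flag also parses first, then overrides the value.
        if not self.__dict__['_parsed']:
            self._parse()
        self.__dict__['_flags'][name] = value


# Created before any flag exists, just like the module-level FLAGS.
FLAGS = LazyFlags(['--learning_rate=0.01'])
_parser.add_argument('--learning_rate', default=0.1, type=float)
_parser.add_argument('--data_dir', default='/tmp/data', type=str)

print(FLAGS.learning_rate)  # 0.01, parsed on this first access
print(FLAGS.data_dir)       # /tmp/data, the default
```

A misspelled name such as FLAGS.learningrate raises AttributeError. That is why the names used when reading flags must match the names used to define them.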
def mark_flag_as_required(flag_name):
def mark_flags_as_required(flag_names):
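These two functions serve the same purpose. They require that the marked flags have a non-None value once parsing happens; otherwise an AttributeError is raised. The former marks a single flag. The latter marks several flags and is implemented by calling the former for each name.

Two consequences follow from the source:

- If you mark a flag as required but never define it in your code, an error is guaranteed at runtime.
- A flag defined with a non-None default always passes the check, even if it is not given on the command line. mark_flag_as_required logs a warning about this, but only if the default is already registered at the time of the call.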
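The required-flag bookkeeping can be sketched the same way. This assumes a plain argparse parser stands in for _global_parser; the hypothetical helpers below mirror mark_flag_as_required and _assert_all_required from the listing and show both consequences described above.

```python
import argparse

parser = argparse.ArgumentParser()  # stand-in for _global_parser
required = set()                    # stand-in for FLAGS' '__required_flags'


def mark_flag_as_required(name):
    # Like the listing: warn only if the flag is already defined with a
    # non-None default (get_default returns None for unknown names).
    if parser.get_default(name) is not None:
        print('warning: flag %s has a non-None default' % name)
    required.add(name)


def parse(argv):
    values = vars(parser.parse_known_args(argv)[0])
    for name in sorted(required):
        # Same test as _assert_required: undefined or None is an error.
        if name not in values or values[name] is None:
            raise AttributeError('Flag --%s must be specified.' % name)
    return values


# Marked before being defined, in the same order as the example in section 5.
mark_flag_as_required('batch_size')
mark_flag_as_required('learning_rate')
parser.add_argument('--learning_rate', default=0.1, type=float)

try:
    parse([])
except AttributeError as err:
    print(err)  # Flag --batch_size must be specified.

parser.add_argument('--batch_size', default=1000, type=int)
print(parse([])['batch_size'])  # 1000: the default alone satisfies the check
```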
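3. How to construct TensorFlow command-line flags
Constructing TensorFlow command-line flags is fairly simple. The usual steps are as follows (all snippets assume "import tensorflow as tf"):
Step 1 (get a handle to the flags module, whose module-level DEFINE_* and mark_* functions are used below):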
flags = tf.app.flags
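Step 2 (optional, only for a stricter check; the docstring in the source recommends calling it after the flags are defined, e.g. inside if __name__ == '__main__': before tf.app.run()):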
flags.mark_flags_as_required(flag_names)
or
flags.mark_flag_as_required(flag_name)
Step 3 (define flags of the various types):
flags.DEFINE_string(flag_name, default_value, docstring)
flags.DEFINE_integer(flag_name, default_value, docstring)
flags.DEFINE_boolean(flag_name, default_value, docstring)
flags.DEFINE_float(flag_name, default_value, docstring)
That completes the construction of TensorFlow command-line flags.
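For comparison with the argparse route mentioned at the start, here is a sketch of the same kind of definitions written directly against the standard library. The flag names follow the example in section 5 and are otherwise arbitrary. Note one difference: argparse's required=True demands that the option appear on the command line, which is stricter than mark_flag_as_required's non-None check.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description='training options')
    # required=True: parse_args exits with an error if --batch_size is absent.
    parser.add_argument('--batch_size', type=int, required=True,
                        help='training batch size')
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='training learning rate')
    parser.add_argument('--data_dir', type=str, default='/home/xsr-ai/study',
                        help='training data directory')
    # store_true is a plain on/off switch; there is no '--no...' form here.
    parser.add_argument('--use_gpu_device', action='store_true',
                        help='use gpu device to training or not')
    return parser


args = build_parser().parse_args(['--batch_size', '64', '--use_gpu_device'])
print(args.batch_size, args.learning_rate, args.use_gpu_device)  # 64 0.1 True
```

Unlike FLAGS, argparse parses eagerly. parse_args is called once, explicitly, and the result is an ordinary Namespace object.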
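4. How to parse TensorFlow command-line flags
Parsing is even simpler. However, the attribute names you read must exactly match the names used when the flags were defined; otherwise an AttributeError is raised. The usual steps are:
Step 1 (get the global FLAGS object, an instance of _FlagValues):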
FLAGS = flags.FLAGS
Step 2 (read any defined flag directly as an attribute, whatever its type; the first access parses sys.argv):
flag_name = FLAGS.flag_name
That completes the parsing of TensorFlow command-line flags.
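One detail of the parsing step is worth seeing in isolation. _parse_flags calls parse_known_args, so a command-line flag whose name is not defined is handed back as unparsed instead of raising an error. A misspelled name in code, by contrast, does raise AttributeError. Here is a minimal standard-library sketch; the argv list simulates a hypothetical command line.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--learning_rate', default=0.1, type=float)
parser.add_argument('--data_dir', default='/tmp/data', type=str)

# Simulates: python train.py --data_dir=/data/mnist --lr=0.01 extra_arg
argv = ['--data_dir=/data/mnist', '--lr=0.01', 'extra_arg']
known, unparsed = parser.parse_known_args(argv)

print(known.data_dir)       # /data/mnist
print(known.learning_rate)  # 0.1: '--lr' is not a defined flag, so the default stays
print(unparsed)             # ['--lr=0.01', 'extra_arg']
```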
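5. A TensorFlow command-line flag example
Below is an example that constructs and parses TensorFlow command-line flags. The flags have no real meaning here. The point is to run the script, observe the output, and get a better understanding of how TensorFlow command-line flags work.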
import tensorflow as tf
flags = tf.app.flags  # construction, step 1
flags.mark_flags_as_required(['batch_size', 'learning_rate', 'data_dir'])  # construction, step 2
# construction, step 3
# batch_size is deliberately left undefined here; see the note after the example.
#flags.DEFINE_integer("batch_size", 1000, "training batch size")
flags.DEFINE_float("learning_rate", 0.1, "training learning rate")
flags.DEFINE_string("data_dir", "/home/xsr-ai/study", "training data directory")
flags.DEFINE_boolean("use_gpu_device", False, "use gpu device to training or not")
# parsing, step 1
FLAGS = flags.FLAGS
# parsing, step 2 (the first attribute access parses sys.argv)
batch_size = FLAGS.batch_size
learning_rate = FLAGS.learning_rate
data_dir = FLAGS.data_dir
use_gpu_device = FLAGS.use_gpu_device
# finally, check the values
print("batch_size=%s" % batch_size)
print("learning_rate=%f" % learning_rate)
print("data_dir=%s" % data_dir)
print("use_gpu_device=%s" % use_gpu_device)
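Running the script as is raises an error (AttributeError: Flag --batch_size must be specified.). This happens because batch_size is marked as required but its definition is commented out:
#flags.DEFINE_integer("batch_size", 1000, "training batch size")
Uncommenting that line fixes the error.
In other words, mark_flags_as_required checks that every flag it marks has been defined and has a non-None value. It does not force the flag to appear on the command line: learning_rate and data_dir pass the check through their defaults.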