slim.arg_scope()函数的定义如下:
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
    """Stores the default arguments for the given set of list_ops.

    For usage, please see examples at top of the file.

    Args:
      list_ops_or_scope: List or tuple of operations to set argument scope for
        or a dictionary containing the current scope. When list_ops_or_scope
        is a dict, kwargs must be empty. When list_ops_or_scope is a list or
        tuple, then every op in it needs to be decorated with @add_arg_scope
        to work.
      **kwargs: keyword=value that will define the defaults for each op in
        list_ops. All the ops need to accept the given set of arguments.

    Yields:
      the current_scope, which is a dictionary of {op: {arg: value}}

    Raises:
      TypeError: if list_ops_or_scope is neither a dict nor a list/tuple.
      ValueError: if a dict is supplied together with kwargs, or if any op in
        list_ops has not been decorated with @add_arg_scope.
    """
    if isinstance(list_ops_or_scope, dict):
        # Assumes that list_ops_or_scope is a scope that is being reused.
        if kwargs:
            raise ValueError('When attempting to re-use a scope by supplying a '
                             'dictionary, kwargs must be empty.')
        current_scope = list_ops_or_scope.copy()
    else:
        # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
        if not isinstance(list_ops_or_scope, (list, tuple)):
            raise TypeError('list_ops_or_scope must either be a list/tuple or '
                            'reused scope (i.e. dict)')
        # Build the new scope on a copy so the enclosing scope is untouched.
        current_scope = current_arg_scope().copy()
        for op in list_ops_or_scope:
            key_op = _key_op(op)
            if not has_arg_scope(op):
                # The one-element tuple keeps the formatting correct whatever
                # _name_op returns (its return type is not visible here).
                raise ValueError('%s is not decorated with @add_arg_scope'
                                 % (_name_op(op),))
            if key_op in current_scope:
                # Merge into a copy: new kwargs override inherited defaults
                # without mutating the dict shared with the outer scope.
                current_kwargs = current_scope[key_op].copy()
                current_kwargs.update(kwargs)
                current_scope[key_op] = current_kwargs
            else:
                current_scope[key_op] = kwargs.copy()

    # Push only after the scope has been built successfully, so that a
    # validation error above can never pop an entry this call did not push.
    _get_arg_stack().append(current_scope)
    try:
        yield current_scope
    finally:
        _get_arg_stack().pop()
如注释中所说,这个函数的作用是给list_ops中的内容设置默认值,但是list_ops中的每个成员都需要用@add_arg_scope修饰才行,所以使用slim.arg_scope()有两个步骤:
1、使用@slim.add_arg_scope修饰目标函数
2、用slim.arg_scope()为目标函数设置默认参数
import tensorflow as tf
slim = tf.contrib.slim
@slim.add_arg_scope
def fun1(a=0, b=0):
    """Return the sum of a and b; each defaults to 0 when not supplied."""
    total = a + b
    return total
# Open a scope that supplies a=10 for fun1, then call fun1 passing only b.
scoped_defaults = slim.arg_scope([fun1], a=10)
with scoped_defaults:
    x = fun1(b=30)
    print(x)