slim.arg_scope()的使用

slim.arg_scope() 函数的使用  

     slim是一种轻量级的tensorflow库,可以使模型的构建,训练,测试都变得更加简单。在slim库中对很多常用的函数进行了定义,slim.arg_scope()是slim库中经常用到的函数之一。函数的定义如下;


@tf_contextlib.contextmanager

def arg_scope(list_ops_or_scope, **kwargs):

  """Stores the default arguments for the given set of list_ops.



  For usage, please see examples at top of the file.



  Args:

    list_ops_or_scope: List or tuple of operations to set argument scope for or

      a dictionary containing the current scope. When list_ops_or_scope is a

      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,

      then every op in it need to be decorated with @add_arg_scope to work.

    **kwargs: keyword=value that will define the defaults for each op in

              list_ops. All the ops need to accept the given set of arguments.



  Yields:

    the current_scope, which is a dictionary of {op: {arg: value}}

  Raises:

    TypeError: if list_ops is not a list or a tuple.

    ValueError: if any op in list_ops has not be decorated with @add_arg_scope.

  """

  if isinstance(list_ops_or_scope, dict):

    # Assumes that list_ops_or_scope is a scope that is being reused.

    if kwargs:

      raise ValueError('When attempting to re-use a scope by suppling a'

                       'dictionary, kwargs must be empty.')

    current_scope = list_ops_or_scope.copy()

    try:

      _get_arg_stack().append(current_scope)

      yield current_scope

    finally:

      _get_arg_stack().pop()

  else:

    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.

    if not isinstance(list_ops_or_scope, (list, tuple)):

      raise TypeError('list_ops_or_scope must either be a list/tuple or reused'

                      'scope (i.e. dict)')

    try:

      current_scope = current_arg_scope().copy()

      for op in list_ops_or_scope:

        key_op = _key_op(op)

        if not has_arg_scope(op):

          raise ValueError('%s is not decorated with @add_arg_scope',

                           _name_op(op))

        if key_op in current_scope:

          current_kwargs = current_scope[key_op].copy()

          current_kwargs.update(kwargs)

          current_scope[key_op] = current_kwargs

        else:

          current_scope[key_op] = kwargs.copy()

      _get_arg_stack().append(current_scope)

      yield current_scope

    finally:

      _get_arg_stack().pop()

     如注释中所说,这个函数的作用是给list_ops中的内容设置默认值。但是每个list_ops中的每个成员需要用@add_arg_scope修饰才行。所以使用slim.arg_scope()有两个步骤:


 

1、使用@slim.add_arg_scope修饰目标函数    
2、用 slim.arg_scope()为目标函数设置默认参数.


     例如如下代码;首先用@slim.add_arg_scope修饰目标函数fun1(),然后利用slim.arg_scope()为它设置默认参数。


import tensorflow as tf

slim =tf.contrib.slim

 

@slim.add_arg_scope

def fun1(a=0,b=0):

    return (a+b)

 

with slim.arg_scope([fun1],a=10):

    x=fun1(b=30)

    print(x)

     运行结果为:

40

    平常所用到的slim.conv2d( ),slim.fully_connected( ),slim.max_pool2d( )等函数在他被定义的时候就已经添加了@add_arg_scope。以slim.conv2d( )为例;

 


@add_arg_scope

def convolution(inputs,

                num_outputs,

                kernel_size,

                stride=1,

                padding='SAME',

                data_format=None,

                rate=1,

                activation_fn=nn.relu,

                normalizer_fn=None,

                normalizer_params=None,

                weights_initializer=initializers.xavier_initializer(),

                weights_regularizer=None,

                biases_initializer=init_ops.zeros_initializer(),

                biases_regularizer=None,

                reuse=None,

                variables_collections=None,

                outputs_collections=None,

                trainable=True,

                scope=None):

 

     所以,在使用过程中可以直接slim.conv2d( )等函数设置默认参数。例如在下面的代码中,不做单独声明的情况下,slim.conv2d, slim.max_pool2d, slim.avg_pool2d三个函数默认的步长都设为1,padding模式都是'VALID'的。但是也可以在调用时进行单独声明。这种参数设置方式在构建网络模型时,尤其是较深的网络时,可以节省时间。

 


 with slim.arg_scope(

                [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],stride = 1, padding = 'VALID'):

            net = slim.conv2d(inputs, 32, [3, 3], stride = 2, scope = 'Conv2d_1a_3x3')

            net = slim.conv2d(net, 32, [3, 3], scope = 'Conv2d_2a_3x3')

            net = slim.conv2d(net, 64, [3, 3], padding = 'SAME', scope = 'Conv2d_2b_3x3')

 

@修饰符     

 

     其实这种用法是python中常用到的。在python中@修饰符放在函数定义的上方,它将被修饰的函数作为参数,并返回修饰后的同名函数。形式如下;

@fun_a     #等价于fun_a(fun_b)
def fun_b():

@fun_a     #等价于fun_a(fun_b)def fun_b():

      这在本质上讲跟直接调用被修饰的函数没什么区别,但是有时候也有用处,例如在调用被修饰函数前需要输出时间信息,我们可以在@后方的函数中添加输出时间信息的语句,这样每次我们只需要调用@后方的函数即可。


def funs(fun,factor=20):

    x=fun()

    print(factor*x)

    

    

@funs     #等价funs(add(),fator=20)

def add(a=10,b=20):

    return(a+b)

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值