利用偏函数设计学习率下降方式

Python partial()

首先,我们需要简单了解下偏函数的作用:和装饰器一样,它可以扩展函数的功能,但又不完成等价于装饰器。通常应用的场景是当我们要频繁调用某个函数时,其中某些参数是已知的固定值,通常我们可以调用这个函数多次,但这样看上去似乎代码有些冗余,而偏函数的出现就是为了很少的解决这一个问题。举一个很简单的例子,比如我就想知道 100 加任意数的和是多少,通常我们的实现方式是这样的:

# 第一种做法:
def add(*args):
    return sum(args)

print(add(1, 2, 3) + 100)
print(add(5, 5, 5) + 100)

# 第二种做法
def add(*args):
    # 对传入的数值相加后,再加上100返回
    return sum(args) + 100

print(add(1, 2, 3))  # 106
print(add(5, 5, 5))  # 115 

看上面的代码,貌似也挺简单,也不是很费劲。但两种做法都会存在有问题:第一种,100这个固定值会返回出现,代码总感觉有重复;第二种,就是当我们想要修改 100 这个固定值的时候,我们需要改动 add 这个方法。下面我们来看下用 parital 怎么实现:

from functools import partial

def add(*args):
	return sum(args)

add_100  = partial(add, 100)
print(add_100(1, 2, 3))

add_200 = partial(add, 101)
print(add_100(3, 4, 5))

是不是很简单~

大概了解了偏函数的例子后,我们再来看一下偏函数的定义:

类func = functools.partial(func, *args, **keywords)

我们可以看到,partial 一定接受三个参数,从之前的例子,我们也能大概知道这三个参数的作用,简单介绍下:

func: 需要被扩展的函数,返回的函数其实是一个类 func 的函数
*args: 需要被固定的位置参数
**kwargs: 需要被固定的关键字参数

如果在原来的函数 func 中关键字不存在,将会扩展,如果存在,则会覆盖

用一个简单的包含位置参数和关键字参数的示例代码来说明用法:

同样是刚刚求和的代码,不同的是加入的关键字参数

def add(*args, **kwargs):
    # 打印位置参数
    for n in args:
        print(n)
    print("-"*20)
    # 打印关键字参数
    for k, v in kwargs.items():
       print('%s:%s' % (k, v))
    # 暂不做返回,只看下参数效果,理解 partial 用法

普通调用

add(1, 2, 3, v1=10, v2=20)
"""
1
2
3
--------------------
v1:10
v2:20
"""

# partial
add_partial = partial(add, 10, k1=10, k2=20)
add_partial(1, 2, 3, k3=20)
"""
10
1
2
3
--------------------
k1:10
k2:20
k3:20
"""

add_partial(1, 2, 3, k1=20)
"""
10
1
2
3
--------------------
k1:20
k2:20
"""

Detection 应用代码

lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n       = iters // step_size
        out_lr  = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters  = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start     = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter         = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        # 偏funcation
        func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        decay_rate  = (min_lr / lr) ** (1 / (step_num - 1))
        step_size   = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func

def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浪子私房菜

给小强一点爱心呗

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值