Pytorch学习笔记:LambdaLR——自定义学习率变化器

Pytorch学习笔记:LambdaLR——自定义学习率变化器

torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose='deprecated')

功能:

  将每个参数的学习率设置为初始的lr乘以一个权重系数factor,用于调整学习率大小,其中权重系数factor由函数lr_lambda得到,这里可以为每个层设置不同的学习率调整策略。

输入:

  • optimizer:优化器;
  • lr_lambda:给定epoch或者,传入函数或list列表;
  • last_epoch:当前的epoch,默认-1;
  • verbose:如果设为True,则每次学习率更新都会输出一条消息(即将弃用,查看学习率可通过调用get_last_lr()实现);

常用方法:

  • get_last_lr():返回当前的学习率

  • state_dict():提取__dict__中的数据(不包括optimizer),如果lr_lambda是一个可调用的对象时,可以被提取,如果是函数或者lambda时,则不会被提取,会得到None

  • load_state_dict(state_dict):加载参数;

代码案例

  对模型中所有参数都使用相同的学习率调整策略,学习率权重因子计算方法如下:
l r = α e p o c h ∗ b a s e _ l r lr=\alpha^{epoch} * base\_lr lr=αepochbase_lr

from torch.optim.lr_scheduler import LambdaLR
from torch.optim import SGD
from torchvision import models


def lambda_lr(epoch, alpha=0.99):
    return alpha ** epoch


model = models.resnet50()
optimizer = SGD(model.parameters(), lr=1e-3)
our_scheduler = LambdaLR(optimizer, lambda_lr)
last_lr = our_scheduler.get_last_lr()

for i in range(100):
    our_scheduler.step()
    last_lr = our_scheduler.get_last_lr()
    print(last_lr)

  对不同的参数层使用不同的学习率调整策略,这里对resnet50的特征提取层和全连接层使用不同的学习率下降策略,其中特征提取层下降速度要快于全连接层。

  首先在定义优化器时,需要将两组参数以不同的键值对传入优化器中,在定义lr_lambda时需要传入两种变化策略(以列表格式传入),注意顺序是一一对应的

from torch.optim.lr_scheduler import LambdaLR
from torch.optim import SGD
from torchvision import models


def fc_lambda_lr(epoch, alpha=0.99):
    return alpha ** epoch


def feature_lambda_lr(epoch, alpha=0.88):
    return alpha ** epoch


model = models.resnet50()
feature_params = []
fc_params = []
for name, param in model.named_parameters():
    if 'fc' in name:
        fc_params.append(param)
    else:
        feature_params.append(param)

optimizer = SGD([
    {'params': feature_params},
    {'params': fc_params}
], lr=1e-3)

our_scheduler = LambdaLR(optimizer, lr_lambda=[feature_lambda_lr, fc_lambda_lr])
last_lr = our_scheduler.get_last_lr()

for i in range(100):
    our_scheduler.step()
    last_lr = our_scheduler.get_last_lr()
    print(last_lr)

官方文档

LambdaLR:https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.LambdaLR.html#lambdalr

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值