Pytorch lr_scheduler 调整学习率

Pytorch lr_scheduler 调整学习率

背景

上篇文章连接

在运行 VGG 代码的时候有这么几行代码:

# 定义模型进行训练
model = VGG16()
# model.load_state_dict(torch.load('./my-VGG16.pth'))
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-3)
loss_func = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1) # todo 了解这个操作!

定义优化器,损失函数我都知道要做,但是突然出现一个 scheduler 我就看不懂是什么了。

这里来了解一下

在PyTorch中,lr_scheduler 是用于调整学习率(Learning Rate)的一个模块,它可以在训练过程中动态地改变学习率,有助于改善模型的训练效果,避免陷入局部最优解,或者加速收敛过程。StepLRlr_scheduler 模块中的一个类,它按照固定的步长(step_size)来降低学习率。

在你给出的代码片段中:

python复制代码

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
  • optimizer:这是你需要调整学习率的优化器对象。在你的例子中,它是通过 optim.SGD 创建的,用于VGG16模型的参数优化。
  • step_size:这个参数定义了学习率更新的周期。在你的例子中,step_size=5 意味着每经过5个epoch(训练周期),学习率就会更新一次。
  • gamma:这个参数定义了学习率更新的乘法因子。在你的例子中,gamma=0.4 意味着每次学习率更新时,新的学习率将是旧学习率的0.4倍。这有助于在训练过程中逐渐减小学习率,以便在接近最优解时进行更细致的调整。
  • last_epoch:这个参数用于指示在调用 scheduler.step() 之前,已经模拟了多少个epoch的更新。在初次设置时,如果希望从头开始计算(即从epoch 0开始),通常将其设置为 -1。这样,第一次调用 scheduler.step() 时,将认为是从epoch 0结束后的状态开始,从而按照 step_size 的设置来更新学习率。

在训练循环中,你需要在每个epoch结束后调用 scheduler.step() 来更新学习率。例如:

for epoch in range(num_epochs):  
    # 训练代码...  
    # 在每个epoch结束后更新学习率  
    scheduler.step()

通过这种方式,你可以根据训练过程的需要,灵活地调整学习率,以期获得更好的训练效果。

实际调用:

# 定义训练步骤
total_times = 40
total = 0
accuracy_rate = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for epoch in range(total_times):
    model.train()
    model.to(device)
    running_loss = 0.0
    total_correct = 0
    total_trainset = 0
    print("epoch: ",epoch)
    for i, (data,labels) in enumerate(train_loader):
        data = data.to(device)
        outputs = model(data).to(device)
        labels = labels.to(device)
        loss = loss_func(outputs,labels).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _,pred = outputs.max(1)
        correct = (pred == labels).sum().item()
        total_correct += correct
        total_trainset += data.shape[0]
        if i % 100 == 0 and i > 0:
            print(f"正在进行第{i}次训练, running_loss={running_loss}".format(i, running_loss))
            running_loss = 0.0
    test()
    scheduler.step()

其他调整学习率的方法

在PyTorch中,除了StepLR之外,还有多种其他方法用于调整学习率。这些方法可以帮助在训练过程中更灵活地控制学习率,以适应不同的训练需求和数据集特性。以下是一些常见的PyTorch学习率调整方法:

  1. MultiStepLR:
    • 功能:按给定间隔调整学习率。
    • 主要参数:
      • milestones:一个列表,包含需要调整学习率的epoch数。在每个milestones指定的epoch结束时,学习率会按照给定的gamma进行调整。
      • gamma:调整系数,与StepLR中的gamma相同,用于计算新的学习率。
    • 使用场景:当你知道在特定的epoch点需要调整学习率时,可以使用此方法。
  2. ExponentialLR:
    • 功能:按指数衰减调整学习率。
    • 主要参数:
      • gamma:指数的底,通常设置为小于1的数(如0.9),用于计算学习率的衰减。
    • 使用场景:当希望学习率随着训练的进行逐渐减小,且减小速度呈指数级变化时,可以使用此方法。
  3. CosineAnnealingLR:
    • 功能:余弦周期调整学习率。
    • 主要参数:
      • T_max:下降周期,表示余弦周期的一半。学习率会在每个T_max周期内按照余弦函数变化。
      • eta_min:学习率下限,学习率变化过程中不会低于此值。
    • 使用场景:当希望学习率在一个周期内先下降后上升,模拟退火过程时,可以使用此方法。
  4. ReduceLROnPlateau:
    • 功能:监控某个指标(如loss或accuracy),当指标不再改善时调整学习率。
    • 主要参数:
      • mode'min''max',表示监控的指标是应该最小化还是最大化。
      • factor:调整系数,用于计算新的学习率。
      • patience:“耐心”参数,表示在调整学习率之前,指标可以接受连续多少次不改善。
      • cooldown:“冷却时间”,在调整学习率后,暂停监控一段时间。
      • min_lr:学习率下限。
    • 使用场景:当你想根据模型的实际表现(而非固定的epoch数)来调整学习率时,可以使用此方法。
  5. LambdaLR
    • 功能:使用自定义的lambda函数来调整学习率。
    • 主要参数:
      • lr_lambda:一个函数或函数列表,用于计算新的学习率。如果传入函数列表,则列表中的每个函数都会独立地应用于每个参数组的学习率。
    • 使用场景:当需要实现更复杂的学习率调整策略时,可以使用此方法。
  6. WarmupLR(注意:这不是PyTorch官方直接提供的一个类,但可以通过自定义实现):
    • 功能:在训练初期逐渐增加学习率,以达到预热模型的效果。
    • 实现方式:可以通过自定义一个lr_scheduler或使用LambdaLR结合预热函数来实现。
    • 使用场景:当模型在训练初期容易因为学习率过大而不稳定时,可以使用此方法。

这些学习率调整方法各有特点,适用于不同的训练场景和需求。在实际应用中,可以根据数据集的特性、模型的复杂度以及训练目标来选择合适的调整方法。

  • 14
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

xwhking

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值