PyTorch学习—14.PyTorch中的学习率调整策略

本文介绍了PyTorch中六种学习率调整策略:StepLR、MultiStepLR、ExponentialLR、CosineAnnealingLR、ReduceLRonPlateau和LambdaLR。学习率调整对于训练过程至关重要,初期设置较大以快速探索,后期减小以精细优化。通过不同的调整策略,如按步长、指数衰减或监控指标变化,可以有效改善模型的训练效果。
摘要由CSDN通过智能技术生成

一、为什么要调整学习率

  在参数更新的过程中,开始时学习率一般给的比较大,后期学习率会给的小一些。这样才会使得总的损失小一些。

二、PyTorch的六种学习率调整策略

  PyTorch的六种学习率调整策略都继承于class _LRScheduler(object)这个基类,所以我们首先介绍这个基类

class _LRScheduler(object):
	def __init__(self, optimizer, last_epoch=-1):
		...

	def get_lr(self):
		...

	def step(self):
		...

主要属性:

  • optimizer:关联的优化器
  • last_epoch:记录epoch数
  • base_lrs:记录初始学习率

主要方法:

  • get_lr():虚函数,计算下一个epoch的学习率
  • step():更新下一个epoch的学习率
1.StepLR
lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

功能:等间隔调整学习率
主要参数:

  • step_size:调整间隔数
  • gamma:调整系数
    调整方式:lr = lr * gamma
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# 设置学习率下降策略
scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):
    # 获取当前lr,新版本用 get_last_lr()函数,旧版本用get_lr()函数,具体看UserWarning
    lr_list.append(scheduler_lr.get_lr())
    epoch_list.append(epoch)

    for i in range(iteration):

        loss = torch.pow((weights - target), 2)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    scheduler_lr.step()

plt.plot(epoch_list, lr_list, label="Step LR Scheduler")
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
plt.legend()
plt.show()

在这里插入图片描述

2.MultiStepLR
lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)

功能:按给定间隔调整学习率
主要参数:

  • milestones:设定调整时刻数
  • gamma:调整系数
    调整方式:lr = lr * gamma
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200

weights = torch.randn((1), requires_grad
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值