更详细用法建议查看pytorch官方文档,如有错误,恳请指出
文章目录
前言
初学者,记录学习率调整的几种方法和实现。
提示:以下是本篇文章正文内容,下面案例可供参考
一、手动设置学习率
import torch
import matplotlib.pyplot as plt
from torch.optim import *
import torch.nn as nn
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 定义网络
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.fc = nn.Linear(1, 10)
def forward(self, x):
return self.fc(x)
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
lr_list = [] # 方便打印学习率曲线
for epoch in range(100):
if epoch % 5 == 0:
for p in optimizer.param_groups:
p['lr'] *= 0.9 # 每5个epoch衰减1次
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color='r')
plt.title('手动设置')
plt.xlabel('epoch')
plt.ylabel('lr')
plt.show()
torch.optim.lr_scheduler 里有许多自动调整学习的方法
二、用 lamda 函数作为学习率的乘积因子更新学习率
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
- optimizer 传入的优化器
- lr_lambda lamda 表达式
- last_epoch 不用管 (默认-1,和训练中断,接着训练的情况有关)
- verbose 设为True,每更新学习率自动打印学习率。 Adjusting learning rate of group 0 to 10e-4
defaule=False
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
lambda1 = lambda epoch: 1 / (epoch + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
for epoch in range(100):
scheduler.step() # 注意,学习率的更新需放到优化器更新后
# lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
lr_list.append(scheduler.get_last_lr()) # 更简洁
plt.plot(range(100), lr_list, color='r')
plt.title('LambdaLR')
plt.xlabel('epoch')
plt.ylabel('lr')
plt.show()
print(lr_list[0], lr_list[1], lr_list[2]) # 0.005 0.003333333333333333 0.0025
三、用 StepLR 固定步长的学习率衰减
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
- optimizer
- step_size 更新步长,多少个
epoch
更新一次 - gamma 固定了的衰减因子
- last_epoch
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)
for epoch in range(100):
scheduler.step()
# lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
lr_list.append(scheduler.get_last_lr())
plt.plot(range(100), lr_list, color='r')
plt.title('StepLR')
plt.xlabel('epoch')
plt.ylabel('lr')
plt.show()
四、用 MultiStepLR 控制衰退时机
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)
- optimizer
- milestones 衰退列表,需升序排列,如[20,50] ,第20、50
epoch
时更新LR
- gamma
- last_epoch
- verbose
lr_list = []
model = net()
LR = 0.01
optimizer = Adam(model.parameters(), lr=LR)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[20, 80], gamma=0.9)
for epoch in range(100):
scheduler.step() # 第20epoch和第80个epoch时学习率衰减
lr_list.append(scheduler.get_last_lr())
plt.plot(range(100), lr_list, color='r')
plt.xlabel('epoch')
plt.ylabel('lr')
plt.title('MultiStepLR')
plt.show()
五、ReduceLROnPlateau 根据训练中某写些指标动态调节学习率
ReduceLROnPlateau(optimizer, mode='min',factor=0.1
, patience=10, threshold=0.0001,
threshold_mode='rel', cooldown=0,
min_lr=0, eps=1e-08,verbose=False)
当一个指标停止改进时,降低learning rate
。一旦学习停滞不前,模型通常会从降低2-10倍的learning rate
中受益。scheduler
会读取一个指标,如果在 patience
个epoch里内没有看到任何改善,learning rate
就会降低。
- optimizer 优化器
- mode 有
min、max
模式,在mode=min
时,监测的指标停止减小时,lr
被降低。mode=max
时,指标停止上升升时,lr被降低。default=min
- factor 即gamma。
new_lr = lr * factor
.default=0.1
。 - patience 如
patience = 2
,当连续2次指标没改善时,第3次如果还没改善,则降低学习率default=10
- threshold
- threshold_mode
- cooldown
- min_lr 学习率的下限
default=0
- eps 如果新旧
lr
之间的差异小于eps
,更新将被忽略default=1e-8
- verbose 每次更新都会向
stdout
打印一条信息。default=False
# 官方示例
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
train(...)
val_loss = validate(...)
# Note that step should be called after validate()
scheduler.step(val_loss) # 监测val_loss
ExponentialLR、CosineAnnealingLR 不常用
六、ExponentialLR
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False)
每训练一个epoch,学习率衰减一次
七、CosineAnnealingLR
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False)
按照余弦波形的衰减周期来更新学习率,前半个周期从最大值降到最小值,后半个周期从最小值升到最大值