# 一些必要的库和参数
import torch
import torch.nn as nn
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np
以SGD为例
model = models.resnet18()
init_lr = 0.1
optimizer = torch.optim.SGD(model.parameters(), init_lr)
# 查看学习率
for param_group in optimizer.param_groups:
print(param_group['lr'])
# 0.1
1. 官方例子里是如下的自定义函数方式,以最常用的调整策略StepLR为例,每隔一定轮数进行改变
# Reference:https://github.com/pytorch/examples/blob/master/imagenet/main.py
# 每30轮学习率乘以0.1
def adjust_learning_rate(optimizer, epoch, init_lr):
"""
optimizer: 优化器
epoch: 训练轮数,也可以根据需要加入其它参数
init_lr:初始学习率,也可以设置为全局变量
"""
lr = init_lr * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
total_epoch = 100
lrs = []
# 每一轮调用函数即可
for epoch in range(total_epoch):
adjust_learning_rate(optimizer, epoch, init_lr)
lrs.append(optimizer.param_groups[0]['lr'])
plt.plot(range(total_epoch), lrs)
plt.title('adjustLR')
plt.savefig('adjustLR.jpg', bbox_inches='tight')
这种方法可以很方便根据自己的逻辑获得想要的学习率变化策略,可以很复杂,也可以很简单。
2. lr_scheduler
PyTorch中提供了多种预设的学习率策略,都包含在torch.optim.lr_scheduler
,详细见 Docs 。
同理,以 StepLR 为例
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
total_epoch = 100
lrs = []
for epoch in range(total_epoch):
lrs.append(optimizer.param_groups[0]['lr'])
# 调用step()即更新学习率
scheduler.step()
plt.plot(range(total_epoch), lrs)
plt.title('StepLR')
可以看到,两种方式的StepLR效果是一样的。
3. 带warmup的学习率调整
3.1 自定义函数
def adjust_learning_rate(optimizer, warm_up_step, epoch, init_lr):
if epoch < warm_up_step:
lr = (epoch + 1) / warm_up_step * init_lr
else:
lr = init_lr * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
3.2 LambdaLR
lr_scheduler中也有一项自定义的学习率调整方法,通过构造匿名函数来实现
lambda_ = lambda epoch: (epoch + 1) / warm_up_step if epoch < warm_up_step else 0.1 ** (epoch // 30)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_)
需要注意的是,自定义函数的方式是直接对优化器中的学习率赋值,而LambdaLR是学习率的权重!
3.3 余弦变换
余弦变换也是常用的学习率调整策略之一,跟steplr可以达到差不多的效果,但是从训练图像上看会更平稳一些。
lambda_ = lambda epoch: (epoch + 1) / warm_up_step if epoch < warm_up_step else 0.5 * (np.cos((epoch - warm_up_step) / (total_epoch - warm_up_step) * np.pi) + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_)
4. Other
其他策略如:余弦退火,指数变换,正弦变换,学习率重启等。