文章目录
一、为什么要调整学习率
在参数更新的过程中,开始时学习率一般给的比较大,后期学习率会给的小一些。这样才会使得总的损失小一些。
二、PyTorch的六种学习率调整策略
PyTorch的六种学习率调整策略都继承于class _LRScheduler(object)
这个基类,所以我们首先介绍这个基类
class _LRScheduler(object):
def __init__(self, optimizer, last_epoch=-1):
...
def get_lr(self):
...
def step(self):
...
主要属性:
- optimizer:关联的优化器
- last_epoch:记录epoch数
- base_lrs:记录初始学习率
主要方法:
- get_lr():虚函数,计算下一个epoch的学习率
- step():更新下一个epoch的学习率
1.StepLR
lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
功能:等间隔调整学习率
主要参数:
- step_size:调整间隔数
- gamma:调整系数
调整方式:lr = lr * gamma
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)
LR = 0.1
iteration = 10
max_epoch = 200
weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))
optimizer = optim.SGD([weights], lr=LR, momentum=0.9)
# 设置学习率下降策略
scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):
# 获取当前lr,新版本用 get_last_lr()函数,旧版本用get_lr()函数,具体看UserWarning
lr_list.append(scheduler_lr.get_lr())
epoch_list.append(epoch)
for i in range(iteration):
loss = torch.pow((weights - target), 2)
loss.backward()
optimizer.step()
optimizer.zero_grad()
scheduler_lr.step()
plt.plot(epoch_list, lr_list, label="Step LR Scheduler")
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
plt.legend()
plt.show()
2.MultiStepLR
lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
功能:按给定间隔调整学习率
主要参数:
- milestones:设定调整时刻数
- gamma:调整系数
调整方式:lr = lr * gamma
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)
LR = 0.1
iteration = 10
max_epoch = 200
weights = torch.randn((1), requires_grad