PyTorch for Beginners Series – Torch.optim API Scheduler (3)
torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs. torch.optim.lr_scheduler.ReduceLROnPlateau additionally allows the learning rate to be reduced dynamically based on some validation measurement.
Learning rate scheduling should be applied after the optimizer's update; i.e., you should write your code this way:
Demo:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()
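The demo above drives an epoch-based scheduler. ReduceLROnPlateau, mentioned earlier, is the exception: its step() consumes the monitored metric itself. A minimal sketch, where the constant val_loss is a hypothetical stand-in for a real validation metric:
import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Cut the lr by 10x once the metric has stopped improving for `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

for epoch in range(20):
    val_loss = 1.0  # hypothetical stand-in: a permanent plateau
    scheduler.step(val_loss)  # pass the monitored value, not just a "tick"
    print(epoch, optimizer.param_groups[0]['lr'])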
Most learning rate schedulers can be called back-to-back (this is also referred to as chaining schedulers). Each scheduler is then applied, one after the other, to the learning rate produced by the scheduler before it.
Demo:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()
In many places in the documentation, the following template is used to refer to the scheduler algorithm.
scheduler = ...

for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
Scheduler source code walkthrough
Reference: https://zhuanlan.zhihu.com/p/346205754?utm_medium=social&utm_oi=73844937195520
The core logic of a learning rate scheduler class is, at every epoch, to compute a learning rate for each parameter group and write it into the lr field of the optimizer's corresponding param group, so that it takes effect when the optimizer updates its learnable parameters from their gradients. All learning rate scheduling classes inherit from torch.optim.lr_scheduler._LRScheduler, and the base class _LRScheduler defines the following methods (a minimal subclass sketch follows the list):
- step(epoch=None): shared by all subclasses
- get_lr(): must be implemented by each subclass
- get_last_lr(): shared by all subclasses
- print_lr(is_verbose, group, lr, epoch=None): displays lr adjustment information
- state_dict(): may be overridden by subclasses
- load_state_dict(state_dict): may be overridden by subclasses
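As a concrete illustration of this interface, here is a hypothetical scheduler that halves the learning rate every step_size epochs; only get_lr() needs to be implemented (HalvingLR and step_size are names invented for this sketch):
import torch
from torch import optim
from torch.optim.lr_scheduler import _LRScheduler

class HalvingLR(_LRScheduler):
    def __init__(self, optimizer, step_size=10, last_epoch=-1, verbose=False):
        # Set our own attributes first: super().__init__ calls self.step(),
        # which in turn calls get_lr() below.
        self.step_size = step_size
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        # base_lrs holds each param group's initial_lr, captured in __init__.
        return [base_lr * 0.5 ** (self.last_epoch // self.step_size)
                for base_lr in self.base_lrs]

model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = HalvingLR(optimizer, step_size=10)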
Initialization, __init__:
class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        # ... (elided)
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # ... (elided)
        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()
Initialization parameters: optimizer is the optimizer instance. last_epoch is the index of the last epoch and defaults to -1, which means the model is being trained from scratch; in that case each param group in the optimizer gets its initial learning rate recorded as initial_lr. If a value greater than -1 is passed for last_epoch, it means training is being resumed from some epoch, and the optimizer's param groups are then required to already carry the initial_lr information. The with_counter wrapper inside the initializer mainly exists to verify that lr_scheduler.step() is called after optimizer.step(). Note that the last line of __init__ calls self.step(), i.e., _LRScheduler has already invoked step() once by the time initialization finishes.
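A small sketch of both branches, assuming a PyTorch version matching the source above; the exact lr values are incidental:
import torch
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 2)

# Fresh run: last_epoch defaults to -1, so __init__ copies 'lr' to 'initial_lr'.
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = ExponentialLR(optimizer, gamma=0.9)
print(optimizer.param_groups[0]['initial_lr'])  # 0.01

# Resume: with last_epoch > -1, 'initial_lr' must already be in the param group,
# otherwise __init__ raises the KeyError shown in the source above.
optimizer2 = optim.SGD(model.parameters(), lr=0.005)
optimizer2.param_groups[0]['initial_lr'] = 0.01
scheduler2 = ExponentialLR(optimizer2, gamma=0.9, last_epoch=4)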
step
When the model finishes one epoch of training, the step() method should be called. After incrementing last_epoch, the method enters an internal context manager and calls the subclass-implemented get_lr() to obtain each param group's learning rate for this epoch, writes the new values into the optimizer's param_groups, and finally records the most recently applied learning rates in self._last_lr, which is what get_last_lr() returns. The context management is provided by the internal class _enable_get_lr_call, which sets a _get_lr_called_within_step attribute on the instance; subclasses can check this attribute.
def step(self, epoch=None):
    # Raise a warning if old pattern is detected
    # https://github.com/pytorch/pytorch/issues/20124
    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                          "initialization. Please, make sure to call `optimizer.step()` before "
                          "`lr_scheduler.step()`. See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

        # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
        elif self.optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                          "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                          "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                          "will result in PyTorch skipping the first value of the learning rate schedule. "
                          "See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    self._step_count += 1

    class _enable_get_lr_call:

        def __init__(self, o):
            self.o = o

        def __enter__(self):
            self.o._get_lr_called_within_step = True
            return self

        def __exit__(self, type, value, traceback):
            self.o._get_lr_called_within_step = False

    with _enable_get_lr_call(self):
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
            self.last_epoch = epoch
            if hasattr(self, "_get_closed_form_lr"):
                values = self._get_closed_form_lr()
            else:
                values = self.get_lr()

    for i, data in enumerate(zip(self.optimizer.param_groups, values)):
        param_group, lr = data
        param_group['lr'] = lr
        self.print_lr(self.verbose, i, lr, epoch)

    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
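The sequencing described above is easy to observe directly; a short sketch (the values follow ExponentialLR's rule, up to floating point rounding):
import torch
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=1.0)
scheduler = ExponentialLR(optimizer, gamma=0.1)

# __init__ already called step() once, so last_epoch starts at 0.
print(scheduler.last_epoch, scheduler.get_last_lr())  # 0 [1.0]

for _ in range(2):
    optimizer.step()    # optimizer first ...
    scheduler.step()    # ... then the scheduler rewrites param_groups' lr
    print(scheduler.last_epoch, scheduler.get_last_lr())
# 1 [0.1]
# 2 [0.01]  (up to float rounding)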
get_lr
The get_lr() method is abstract: it defines the interface for the learning rate update policy, and each subclass provides its own implementation. Its return value is a list of the form [lr1, lr2, ...], one entry per param group. For example, ExponentialLR implements it as follows:
def get_lr(self):
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)

    if self.last_epoch == 0:
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
            for group in self.optimizer.param_groups]
get_last_lr
def get_last_lr(self):
    """ Return last computed learning rate by current scheduler.
    """
    return self._last_lr
print_lr
print_lr(is_verbose, group, lr, epoch=None): this method displays the lr adjustment information.
def print_lr(self, is_verbose, group, lr, epoch=None):
    """Display the current learning rate.
    """
    if is_verbose:
        if epoch is None:
            print('Adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(group, lr))
        else:
            epoch_str = ("%.2f" if isinstance(epoch, float) else
                         "%.5d") % epoch
            print('Epoch {}: adjusting learning rate'
                  ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
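Passing verbose=True at construction (supported in PyTorch versions matching the source above) is enough to see these messages:
import torch
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = ExponentialLR(optimizer, gamma=0.9, verbose=True)
# Adjusting learning rate of group 0 to 1.0000e-02.   (from the step() in __init__)
scheduler.step()
# Adjusting learning rate of group 0 to 9.0000e-03.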
Other interfaces
- state_dict(): returns, as a dict, all attributes of the current instance except self.optimizer
- load_state_dict(state_dict): reloads previously saved state information
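These two methods are what make a scheduler checkpointable alongside its model and optimizer; a minimal sketch (the checkpoint keys and file name are arbitrary):
import torch
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = ExponentialLR(optimizer, gamma=0.9)

# Save everything needed to resume; the scheduler's state_dict carries
# last_epoch, base_lrs, _last_lr, etc., but not the optimizer reference.
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, 'ckpt.pth')

# ... later, to resume:
ckpt = torch.load('ckpt.pth')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])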