1. Problem Description
Problem description:
When using the torch.optim.lr_scheduler.StepLR learning-rate scheduler, I tried adjusting the last_epoch parameter. Its default value is -1, but when I changed it to 2, the code raised: KeyError: "param 'initial_lr' is not specified in param_groups[0] when resuming an optimizer".
Error analysis:
The error message says that when last_epoch is not -1, each of the optimizer's parameter groups (param_groups) must explicitly contain an initial_lr field. The StepLR scheduler tries to read initial_lr from param_groups and raises a KeyError if the field is missing.
Complete code that reproduces the error:
import torch
from torch import optim
import torch.nn as nn

# Create a simple model
net = nn.Linear(3, 4)

def train():
    # Initialize the optimizer (note: no initial_lr is added to its param_groups)
    optimizer = optim.Adam(net.parameters(), lr=0.1)
    # Set up the StepLR scheduler with last_epoch != -1
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1, last_epoch=2)
    for epoch in range(3):
        # Print the current epoch, the scheduler's learning rate, and the optimizer's learning rate
        print(f"Epoch {epoch}:")
        print("scheduler.get_last_lr():", scheduler.get_last_lr())
        print("optimizer.param_groups['lr']:", [group["lr"] for group in optimizer.param_groups])
        print("=======================================================================")
        # Perform an optimization step
        optimizer.step()
        # Update the learning rate
        scheduler.step()

if __name__ == "__main__":
    train()
2. StepLR Source Code Analysis
First, let's look at the StepLR scheduler's source code. In its parent class LRScheduler we find the following:
class LRScheduler:
    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch
We can see that when last_epoch != -1, LRScheduler checks whether initial_lr is present in each of the optimizer's param_groups and raises a KeyError if the field is missing. When last_epoch == -1, it fills in initial_lr itself via setdefault.
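As a quick contrast (a minimal sketch based on the source above, not part of the original error report): the same optimizer works with the default last_epoch=-1 because the constructor adds initial_lr itself, while passing last_epoch=2 on a fresh optimizer reproduces the KeyError:
import torch.nn as nn
from torch import optim

net = nn.Linear(3, 4)

# Default last_epoch=-1: the scheduler adds 'initial_lr' to each param_group via setdefault
optimizer = optim.Adam(net.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
print(optimizer.param_groups[0]["initial_lr"])  # 0.1

# last_epoch != -1 on a fresh optimizer without 'initial_lr' raises the KeyError
optimizer2 = optim.Adam(net.parameters(), lr=0.1)
try:
    optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.1, last_epoch=2)
except KeyError as e:
    print(e)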
3. Debugging: Printing optimizer.param_groups
To understand the problem better, we can print the contents of optimizer.param_groups:
optimizer = optim.Adam(net.parameters(), lr=0.1)
for param_group in optimizer.param_groups:
    print(param_group)
Output:
{'params': [Parameter containing:
tensor([[-0.5740, 0.1961, 0.0222],
[-0.0027, -0.5017, 0.4377],
[-0.1409, 0.2160, -0.3976],
[ 0.0839, 0.2271, 0.2851]], requires_grad=True), Parameter containing:
tensor([ 0.2809, -0.0746, -0.1029, 0.2433], requires_grad=True)], 'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None}
As we can see, param_groups does not contain an initial_lr field, which is the root cause of the error.
4. Fixing the Error
To fix the problem, we can explicitly add an initial_lr field to each parameter group: after creating the optimizer, manually set initial_lr on every param_group:
optimizer = optim.Adam(net.parameters(), lr=0.1)
for param_group in optimizer.param_groups:
    param_group["initial_lr"] = 0.1  # manually add the initial_lr field
    print(param_group)
Output:
{'params': [Parameter containing:
tensor([[ 0.2887, 0.5005, -0.5337],
[-0.0699, -0.1790, -0.1502],
[-0.5572, 0.1958, 0.3659],
[ 0.2058, 0.2993, 0.4837]], requires_grad=True), Parameter containing:
tensor([ 0.5452, -0.1911, 0.0235, -0.3962], requires_grad=True)], 'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False,
'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.1}
After adding initial_lr, the code runs without raising the error.
Complete code after the fix:
import torch
from torch import optim
import torch.nn as nn

# Create a simple model
net = nn.Linear(3, 4)

def train():
    optimizer = optim.Adam(net.parameters(), lr=0.1)
    # Explicitly specify the initial learning rate for each parameter group
    for param_group in optimizer.param_groups:
        param_group["initial_lr"] = 0.1
    # Set up the StepLR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1, last_epoch=2)
    for epoch in range(3):
        # Print the current epoch, the scheduler's learning rate, and the optimizer's learning rate
        print(f"Epoch {epoch}:")
        print("scheduler.get_last_lr():", scheduler.get_last_lr())
        print("optimizer.param_groups['lr']:", [group["lr"] for group in optimizer.param_groups])
        print("=======================================================================")
        # Perform an optimization step
        optimizer.step()
        # Update the learning rate
        scheduler.step()

if __name__ == "__main__":
    train()
5. What Does the last_epoch Parameter Do?
5.1 When last_epoch = 1, 2, 3, 4, ...
Test code:
import torch
from torch import optim
import torch.nn as nn

# Create a simple model
net = nn.Linear(3, 4)

def train():
    optimizer = optim.Adam(net.parameters(), lr=0.1)
    # Explicitly specify the initial learning rate for each parameter group
    for param_group in optimizer.param_groups:
        param_group["initial_lr"] = 0.1
    # Set up the StepLR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1, last_epoch=2)
    for epoch in range(3):
        # Print the current epoch, the scheduler's state, and the optimizer's learning rate
        print(f"Epoch {epoch}:")
        print("scheduler.state_dict():", scheduler.state_dict())
        print("optimizer.param_groups['lr']:", [group["lr"] for group in optimizer.param_groups])
        print("=======================================================================")
        # Perform an optimization step
        optimizer.step()
        # Update the learning rate
        scheduler.step()

if __name__ == "__main__":
    train()
Output:
Epoch 0:
scheduler.state_dict(): {'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 3, 'verbose': False, '_step_count': 1, '_get_lr_called_within_step': False, '_last_lr': [0.010000000000000002]}
optimizer.param_groups['lr']: [0.010000000000000002]
=======================================================================
Epoch 1:
scheduler.state_dict(): {'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 4, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [0.0010000000000000002]}
optimizer.param_groups['lr']: [0.0010000000000000002]
=======================================================================
Epoch 2:
scheduler.state_dict(): {'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 5, 'verbose': False, '_step_count': 3, '_get_lr_called_within_step': False, '_last_lr': [0.00010000000000000003]}
optimizer.param_groups['lr']: [0.00010000000000000003]
=======================================================================
Analysis:
- With last_epoch set to 2, at Epoch 0 the scheduler's last_epoch is 3 and last_lr = 0.01
- With last_epoch set to 3, at Epoch 0 the scheduler's last_epoch is 4 and last_lr = 0.01
- With last_epoch set to 4, at Epoch 0 the scheduler's last_epoch is 5 and last_lr = 0.01
- Why does this happen? Looking at the source (and noting that _step_count is already 1 at Epoch 0), the scheduler performs one step() when it is constructed: last_epoch is incremented by 1, and because the result is a nonzero multiple of step_size (here step_size=1), get_lr() returns last_lr = lr * gamma, so the learning rate is already 0.01 before the training loop starts.
Source location: the step() method of the LRScheduler class in /usr/local/lib/python3.9/site-packages/torch/optim/lr_scheduler.py
def step(self, epoch=None):
    # Raise a warning if old pattern is detected
    # https://github.com/pytorch/pytorch/issues/20124
    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                          "initialization. Please, make sure to call `optimizer.step()` before "
                          "`lr_scheduler.step()`. See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

        # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
        elif self.optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                          "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                          "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                          "will result in PyTorch skipping the first value of the learning rate schedule. "
                          "See more details at "
                          "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    self._step_count += 1

    with _enable_get_lr_call(self):
        if epoch is None:
            self.last_epoch += 1
            values = self.get_lr()
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
            self.last_epoch = epoch
            if hasattr(self, "_get_closed_form_lr"):
                values = self._get_closed_form_lr()
            else:
                values = self.get_lr()

    for i, data in enumerate(zip(self.optimizer.param_groups, values)):
        param_group, lr = data
        param_group['lr'] = lr
        self.print_lr(self.verbose, i, lr, epoch)

    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
In this function:
- self._step_count starts at 0 and is incremented with self._step_count += 1, so no matter what last_epoch is set to, _step_count starts from 1
- self._last_lr comes from values, and values come from self.get_lr(). Looking at the get_lr() source, when (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0) holds it returns the current lr unchanged; otherwise it returns [group['lr'] * self.gamma for group in self.optimizer.param_groups]
Source of self.get_lr():
def get_lr(self):
    if not self._get_lr_called_within_step:
        warnings.warn("To get the last learning rate computed by the scheduler, "
                      "please use `get_last_lr()`.", UserWarning)

    if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
        return [group['lr'] for group in self.optimizer.param_groups]
    return [group['lr'] * self.gamma
            for group in self.optimizer.param_groups]
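To make the numbers above concrete, here is a minimal hand-trace of the step()/get_lr() logic in plain Python (no torch; step_size=1, gamma=0.1 and the construction-time step are assumed as described above):
lr, gamma, step_size = 0.1, 0.1, 1
last_epoch = 2  # the value passed to StepLR

def step_once(lr, last_epoch):
    # step(): increment last_epoch, then get_lr() decides whether to decay
    last_epoch += 1
    if last_epoch != 0 and last_epoch % step_size == 0:
        lr = lr * gamma
    return lr, last_epoch

# The step performed inside the scheduler's constructor (why _step_count is already 1 at Epoch 0)
lr, last_epoch = step_once(lr, last_epoch)
print(last_epoch, lr)   # 3 0.01...  -> matches the Epoch 0 output above

# scheduler.step() at the end of Epoch 0 and Epoch 1
lr, last_epoch = step_once(lr, last_epoch)
print(last_epoch, lr)   # 4 0.001...
lr, last_epoch = step_once(lr, last_epoch)
print(last_epoch, lr)   # 5 0.0001...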
5.2 When last_epoch uses the default value -1
Test code:
import torch
from torch import optim
import torch.nn as nn

# Create a simple model
net = nn.Linear(3, 4)

def train():
    # Initialize the optimizer (no initial_lr is needed when last_epoch keeps its default -1)
    optimizer = optim.Adam(net.parameters(), lr=0.1)
    # Set up the StepLR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    for epoch in range(3):
        # Print the current epoch, the scheduler's state, and the optimizer's learning rate
        print(f"Epoch {epoch}:")
        print("scheduler.state_dict():", scheduler.state_dict())
        print("optimizer.param_groups['lr']:", [group["lr"] for group in optimizer.param_groups])
        print("=======================================================================")
        # Perform an optimization step
        optimizer.step()
        # Update the learning rate
        scheduler.step()

if __name__ == "__main__":
    train()
Output:
Epoch 0:
scheduler.state_dict(): {'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 0, 'verbose': False, '_step_count': 1, '_get_lr_called_within_step': False, '_last_lr': [0.1]}
optimizer.param_groups['lr']: [0.1]
=======================================================================
Epoch 1:
scheduler.state_dict(): {'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 1, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [0.010000000000000002]}
optimizer.param_groups['lr']: [0.010000000000000002]
=======================================================================
Epoch 2:
scheduler.state_dict(): {'step_size': 1, 'gamma': 0.1, 'base_lrs': [0.1], 'last_epoch': 2, 'verbose': False, '_step_count': 3, '_get_lr_called_within_step': False, '_last_lr': [0.0010000000000000002]}
optimizer.param_groups['lr']: [0.0010000000000000002]
=======================================================================
Analysis:
With the default last_epoch=-1, the construction-time step turns last_epoch into 0, and get_lr() does not decay when last_epoch == 0, so the learning rate stays at 0.1 for the first epoch; the decay only starts from the subsequent scheduler.step() calls.
6. Summary
- When using torch.optim.lr_scheduler.StepLR with last_epoch != -1, initial_lr must be explicitly specified in the optimizer's param_groups, otherwise a KeyError is raised.
- Manually adding an initial_lr field to each of the optimizer's param_groups fixes the error.
- The last_epoch parameter controls the scheduler's initial state, i.e. from which epoch the learning-rate schedule continues.
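In practice, last_epoch is mostly relevant when resuming training from a checkpoint. Below is a minimal sketch of one common pattern (the checkpoint file name "checkpoint.pth" and the saving step are assumptions for illustration, not taken from this article): restoring optimizer.state_dict() brings back the param_groups entries (including initial_lr when it was present at save time), and the scheduler can either be restored via scheduler.load_state_dict() or recreated with an explicit last_epoch:
import torch
from torch import optim
import torch.nn as nn

net = nn.Linear(3, 4)
optimizer = optim.Adam(net.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# ... after training for a few epochs, save a checkpoint (hypothetical file name)
torch.save({"epoch": 2,
            "model": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()}, "checkpoint.pth")

# --- resuming later ---
ckpt = torch.load("checkpoint.pth")
net.load_state_dict(ckpt["model"])
optimizer = optim.Adam(net.parameters(), lr=0.1)
optimizer.load_state_dict(ckpt["optimizer"])  # restores param_groups, including initial_lr

# Option 1: recreate the scheduler with the default last_epoch=-1 and restore its state
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
scheduler.load_state_dict(ckpt["scheduler"])

# Option 2 (the situation discussed in this article): pass last_epoch explicitly; this requires
# initial_lr in param_groups, which optimizer.load_state_dict() above has already restored
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1, last_epoch=ckpt["epoch"])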