Optimizer
torch.optim
每个 optimizer
中有一个 param_groups
维护一组参数更新,其中包含了诸如学习率之类的超参数。通过访问 pprint(opt.param_group)
可以查看或者修改
[
{'dampening': 0,
'lr': 0.01,
'momentum': 0,
'nesterov': False,
'params': [Parameter containing:
tensor([[-0.4239, 0.2810, 0.3866],
[ 0.1081, -0.3685, 0.4922],
[ 0.1043, 0.5353, -0.1368],
[ 0.5171, 0.3946, -0.3541],
[ 0.2255, 0.4731, -0.4114]], requires_grad=True),
Parameter containing:
tensor([ 0.3145, -0.5053, -0.1401, -0.1902, -0.5681], requires_grad=True)],
'weight_decay': 0},
{'dampening': 0,
'lr': 0.01,
'momentum': 0,
'nesterov': False,
'params': [Parameter containing:
tensor([[[[ 0.0476, 0.2790],
[ 0.0285, -0.1737]],
[[-0.0268, 0.2334],
[-0.0095, -0.1972]],
[[-0.1588, -0.1018],
[ 0.2712, 0.2416]]]], requires_grad=True),
Parameter containing:
tensor([ 0.0690, -0.2328, -0.0965], requires_grad=True)],
'weight_decay': 0}
]
每一组 param_group 有不同的参数,对应模型中不同的 Parameter
。
- 基本操作
add_param_group
用于添加新的参数。API
传入应该是一个字典,类似于 optim.param_group
列表中的一个元素,包含了诸如 params
, lr
等参数。如果没有则用初始化 Optimizer
时的默认参数代替。
.step()
API
更新所有参数组。
更新时满足两个条件:
- 参数是有 grad (requires_grad = True)。
- 在
optimizer
的param_group
内。
Pytorch Lightning
关于手动执行参数更新参考文档。关于配置 Lightning_Module
优化器参考此处。
Multiple Optimizer
如果 configure_optimizers
返回多个optimizer
,在 training_step
中会增加一个额外的 optimizer_idx
参数。此时根据需要的更新参数计算 Loss 然后返回。
Manual Optimization
如果需要手动执行参数更新。Set self.automatic_optimization=False
in your LightningModule’s __init__
。
Use the following functions and call them manually:
-
self.optimizers()
to access your optimizers (one or multiple). Useself.lr_schedulers()
to access your schedulers. -
optimizer.zero_grad()
to clear the gradients from the previous training step -
self.manual_backward(loss)
instead ofloss.backward()
-
optimizer.step()
to update your model parameters.scheduler.step()
to schedule your learning rate.
官方示例
from pytorch_lightning import LightningModule
class MyModel(LightningModule):
def __init__(self):
super().__init__()
# Important: This property activates manual optimization.
self.automatic_optimization = False
def training_step(self, batch, batch_idx):
opt = self.optimizers()
loss = self.compute_loss(batch)
self.manual_backward(loss)
# accumulate gradients of `n` batches
if (batch_idx + 1) % n == 0:
opt.step()
opt.zero_grad()