Optimizers are imported from torch.optim. Every optimizer inherits from the Optimizer class, so we start the analysis with Optimizer and then move on to SGD and Adam.
class Optimizer:
    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
params holds the model parameters passed to the optimizer, i.e. model.parameters(); defaults collects the remaining hyperparameters. For example, in SGD:
defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov)
super(SGD, self).__init__(params, defaults)
SGD then calls the parent class Optimizer's initializer via super(). Inside Optimizer, the model parameters params together with the optimizer hyperparameters are registered through self.add_param_group().
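To make the two accepted forms of params concrete, here is a minimal illustrative sketch (the nn.Linear model and the hyperparameter values are just placeholders):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)

# Form 1: a plain iterable of Tensors -- Optimizer wraps it into a single
# param_group {'params': [...]} and fills in the defaults (lr, momentum, ...).
opt1 = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Form 2: a list of dicts -- each dict becomes its own param_group and may
# override individual defaults (here a different lr for the bias).
opt2 = optim.SGD([
    {'params': model.weight},
    {'params': model.bias, 'lr': 0.01},
], lr=0.1, momentum=0.9)

print(len(opt1.param_groups), len(opt2.param_groups))  # 1 2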
opt.zero_grad(): for each parameter, the gradient is detached from the graph with p.grad.detach_() and then cleared with p.grad.zero_().
def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
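For context, a minimal sketch of a training step in which zero_grad() is called before backward() (the model, loss, and data below are placeholders):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

x, y = torch.randn(4, 10), torch.randn(4, 2)

optimizer.zero_grad()          # clear gradients accumulated by the previous step
loss = criterion(model(x), y)
loss.backward()                # p.grad is (re)populated here
optimizer.step()               # update parameters using the fresh gradients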
Saving and restoring the Optimizer:
def __getstate__(self):
    return {
        'defaults': self.defaults,
        'state': self.state,
        'param_groups': self.param_groups,
    }

def __setstate__(self, state):
    self.__dict__.update(state)

def state_dict(self):
    def pack_group(group):
        packed = {k: v for k, v in group.items() if k != 'params'}
        packed['params'] = [id(p) for p in group['params']]
        return packed
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
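Building on state_dict(), a small sketch of the usual save/restore round trip with load_state_dict() (the file name checkpoint.pt is illustrative):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# In the version shown above, state_dict() replaces Tensor references with ids,
# so the returned dict can be serialized alongside the model's own state_dict.
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pt')

# When resuming, rebuild the model/optimizer first, then restore both states.
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])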