Optimizer
Parameter updates are driven mainly by the step function.
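To put the pieces analyzed below in context, here is a minimal usage sketch of a typical training loop; the toy model and data are placeholders, not part of the source being analyzed.

import torch
import torch.nn as nn

# hypothetical toy model and data, only to show where the optimizer fits
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

optimizer.zero_grad()                              # clear old gradients
loss = nn.functional.mse_loss(model(x), y)
loss.backward()                                    # fill p.grad for every parameter
optimizer.step()                                   # apply the update analyzed below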
The __init__ function of the SGD class
# params: the parameters of the network model
# the remaining arguments are packed into the dict `defaults`
def __init__(self, params, lr=required, momentum=0, dampening=0,
             weight_decay=0, nesterov=False):
    if lr is not required and lr < 0.0:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if momentum < 0.0:
        raise ValueError("Invalid momentum value: {}".format(momentum))
    if weight_decay < 0.0:
        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

    # pack the hyperparameters into defaults
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                    weight_decay=weight_decay, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
        raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    # hand params and defaults to the parent class Optimizer's __init__
    super(SGD, self).__init__(params, defaults)
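For reference, params may be an iterable of Tensors or an iterable of dicts; in the latter case each dict becomes one parameter group with its own hyperparameters. A small sketch, with a made-up two-layer model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 1))

optimizer = torch.optim.SGD([
    {'params': model[0].parameters()},               # uses the default lr given below
    {'params': model[1].parameters(), 'lr': 0.001},  # overrides lr for this group
], lr=0.01, momentum=0.9)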
The __init__ function of the Optimizer class
def __init__(self, params, defaults):
    torch._C._log_api_usage_once("python.optimizer")
    self.defaults = defaults

    self._hook_for_profile()

    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))

    # defaultdict(dict) means that looking up a missing key does not raise
    # KeyError but returns a default value, here an empty dict
    self.state = defaultdict(dict)
    self.param_groups = []

    # the parameters end up as the value under the key 'params' of each group
    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]

    # build self.param_groups
    for param_group in param_groups:
        self.add_param_group(param_group)
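A small sketch of what this constructor produces; the exact set of keys in a group depends on the PyTorch version, since it mirrors defaults.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# the plain iterable of tensors was wrapped into a single group {'params': [...]}
print(len(optimizer.param_groups))       # 1
# each group also carries the hyperparameters copied from defaults
print(optimizer.param_groups[0]['lr'])   # 0.01
# self.state is a defaultdict(dict): an unseen parameter maps to an empty dict
print(optimizer.state[model.weight])     # {}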
The add_param_group function of the Optimizer class
# append param_group to self.param_groups
def add_param_group(self, param_group):
    assert isinstance(param_group, dict), "param group must be a dict"

    # take out the network parameters
    params = param_group['params']
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
        raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                        'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
        # re-pack the network parameters into a list
        param_group['params'] = list(params)

    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")

    # fill the group with any hyperparameter from self.defaults it did not specify
    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                             name)
        else:
            param_group.setdefault(name, default)

    params = param_group['params']
    if len(params) != len(set(params)):
        warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                      "in future, this will cause an error; "
                      "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")

    # append the dict param_group to the list self.param_groups
    self.param_groups.append(param_group)
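add_param_group can also be called from user code after the optimizer has been built, for example to start optimizing previously frozen parameters; the layer names in this sketch are illustrative only.

import torch
import torch.nn as nn

backbone = nn.Linear(10, 10)
head = nn.Linear(10, 1)

# start by optimizing only the head
optimizer = torch.optim.SGD(head.parameters(), lr=0.01)

# later, add the backbone as a second group with its own lr;
# any hyperparameter it does not specify is filled in from self.defaults via setdefault
optimizer.add_param_group({'params': backbone.parameters(), 'lr': 0.001})
print(len(optimizer.param_groups))  # 2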
The step function of the SGD class
@torch.no_grad()
def step(self, closure=None):
    # both the model parameters and the optimizer hyperparameters are stored
    # in the entries of the list self.param_groups
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        d_p_list = []
        momentum_buffer_list = []
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        lr = group['lr']

        # the two nested loops visit every parameter p of the model
        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                # collect the gradient d_p
                d_p_list.append(p.grad)

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])

        # the update itself is delegated to the functional sgd
        F.sgd(params_with_grad,
              d_p_list,
              momentum_buffer_list,
              weight_decay=weight_decay,
              momentum=momentum,
              lr=lr,
              dampening=dampening,
              nesterov=nesterov)

        # update momentum_buffers in state
        for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
            state = self.state[p]
            state['momentum_buffer'] = momentum_buffer

    return loss
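The closure argument lets step re-evaluate the loss inside torch.enable_grad(); for SGD it is optional, but a sketch of how it would be used (the model and data are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def closure():
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # step() runs closure() under torch.enable_grad()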
The sgd function in the functional module (F.sgd)
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool):

    for i, param in enumerate(params):
        # get the gradient d_p
        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        # depending on the optimizer settings, adjust d_p further with momentum
        # and, optionally, Nesterov momentum
        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        # apply the update: param <- param - lr * d_p
        param.add_(d_p, alpha=-lr)
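Putting the branches together: with weight decay the effective gradient is g = grad + weight_decay * param; with momentum the buffer is updated as buf = momentum * buf + (1 - dampening) * g and the step uses buf (or g + momentum * buf with Nesterov); finally param is decreased by lr times that quantity. A small sketch checking the plain-SGD case by hand:

import torch

p = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (p ** 2).sum()      # gradient of loss w.r.t. p is 2 * p
loss.backward()

optimizer = torch.optim.SGD([p], lr=0.1)
optimizer.step()

# plain SGD (no momentum, no weight decay): p <- p - lr * grad = p - 0.1 * 2p = 0.8 * p
print(p)                   # tensor([0.8000, 1.6000], requires_grad=True)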