@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Args:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:        # iterate over the parameter groups
        params_with_grad = []              # parameters that actually have gradients
        d_p_list = []                      # the gradients of those parameters
        momentum_buffer_list = []          # their momentum buffers (None if not created yet)
        # read this group's hyper-parameters
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        lr = group['lr']

        for p in group['params']:          # walk through every parameter in the group
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])  # collect the existing momentum buffers

        # F.sgd receives the collected lists and, using this group's hyper-parameters,
        # updates the parameter weights and the momentum buffers in place
        # (the lists are passed by reference, so they are mutated directly)
        F.sgd(params_with_grad,
              d_p_list,
              momentum_buffer_list,
              weight_decay=weight_decay,
              momentum=momentum,
              lr=lr,
              dampening=dampening,
              nesterov=nesterov)

        # update momentum_buffers in state
        for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
            state = self.state[p]
            state['momentum_buffer'] = momentum_buffer

    return loss
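The internals of F.sgd are not shown above (in torch/optim/sgd.py, F is the functional helper module, torch.optim._functional in the versions that expose F.sgd). For reference, the update it applies follows the SGD-with-momentum rule documented for torch.optim.SGD. The sketch below mirrors that formula; the name sgd_update_sketch is illustrative, not the library's own function.

import torch

def sgd_update_sketch(params, d_p_list, momentum_buffer_list, *,
                      weight_decay, momentum, lr, dampening, nesterov):
    # Sketch of the per-parameter update, not the verbatim library source.
    for i, param in enumerate(params):
        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)   # g <- g + weight_decay * p
        if momentum != 0:
            buf = momentum_buffer_list[i]
            if buf is None:
                # first step: initialise the buffer with the (decayed) gradient
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf          # list mutated in place, so step() sees the new buffer
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)  # buf <- momentum*buf + (1-dampening)*g
            d_p = d_p.add(buf, alpha=momentum) if nesterov else buf
        param.add_(d_p, alpha=-lr)                     # p <- p - lr * d_p

Because the buffers and parameters are updated in place, the trailing loop in step() only has to write the (possibly newly created) buffers back into self.state.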
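For completeness, a minimal usage sketch of step() with a closure. The names model, criterion, data and target are assumed to be defined elsewhere; for plain SGD the closure is optional, and calling optimizer.step() after loss.backward() works just as well.

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

def closure():
    # re-runs the forward/backward pass and returns the loss;
    # step() invokes it under torch.enable_grad() despite the @torch.no_grad() decorator
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    return loss

loss = optimizer.step(closure)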