优化器pytoch代码分析

本文深入探讨了PyTorch中的优化器实现,从Optimizer基类到具体的SGD和Adam优化器。详细分析了参数的初始化、zero_grad操作、状态的存储与恢复过程,并解释了step函数中不同优化器的更新逻辑,特别是momentum和weight_decay的影响。同时,对Adam优化器的自适应学习率进行了说明。
摘要由CSDN通过智能技术生成

通过torch.optim导入优化器。优化器继承class Optimizer,从Optimizer开始分析,再到SGD和Adam。

Class Optimizer:

def __init__(self, params, defaults):
    torch._C._log_api_usage_once("python.optimizer")
    self.defaults = defaults

    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))

    self.state = defaultdict(dict)
    self.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]

    for param_group in param_groups:
        self.add_param_group(param_group)

params为传入优化器的模型参数即model.parameters(),defaults为其他参数,如在SGD中:

defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov)

super(SGD, self).__init__(params, defaults)

然后通过super调用父类Optimizer的初始化方法。在Optimizer中将模型参数params以及优化器参数放在self.add_param_group()中。

opt.zero_grad:设置参数p.grad.detach_()和p.grad.zero_()。

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

Optimizer的存储与恢复:

def __getstate__(self):
    return {
        'defaults': self.defaults,
        'state': self.state,
        'param_groups': self.param_groups,
    }

def __setstate__(self, state):
    self.__dict__.update(state)

def state_dict(self):

    def pack_group(group):
        packed = {k: v for k, v in group.items() if k != 'params'}
        packed['params'] = [id(p) for p in group['params']]
        return packed
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    return {
        'state': packed_state,
        'param_groups': param_grou
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值