Reading PyTorch's AdaGrad Optimizer Source Code

1. The AdaGrad Algorithm

The algorithm as presented in the Deep Learning book (Goodfellow et al., the "flower book"):

[Figure: AdaGrad pseudocode from the Deep Learning book]
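For reference, the core update in the book's notation (global learning rate $\epsilon$, small constant $\delta$ for numerical stability, gradient $g_t$, squared-gradient accumulator $r_t$) is:

$$ r_t = r_{t-1} + g_t \odot g_t $$
$$ \theta_t = \theta_{t-1} - \frac{\epsilon}{\delta + \sqrt{r_t}} \odot g_t $$

Because $r_t$ only grows, parameters that have received large gradients in the past get smaller effective learning rates.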

2. Source Code

def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad  # gradient of this parameter
            state = self.state[p]

            state['step'] += 1

            if group['weight_decay'] != 0:
                if p.grad.is_sparse:
                    raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                # fold L2 regularization into the gradient:
                # grad = grad + weight_decay * p
                grad = grad.add(p, alpha=group['weight_decay'])
            # the effective learning rate decays as the step count grows:
            # clr = lr / (1 + (step - 1) * lr_decay)
            clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay'])

            if grad.is_sparse:  # sparse gradient path
                grad = grad.coalesce()  # the update is non-linear so indices must be unique
                grad_indices = grad._indices()
                grad_values = grad._values()
                size = grad.size()

                def make_sparse(values):
                    constructor = grad.new
                    if grad_indices.dim() == 0 or values.dim() == 0:
                        return constructor().resize_as_(grad)
                    return constructor(grad_indices, values, size)
                state['sum'].add_(make_sparse(grad_values.pow(2)))
                std = state['sum'].sparse_mask(grad)
                std_values = std._values().sqrt_().add_(group['eps'])
                p.add_(make_sparse(grad_values / std_values), alpha=-clr)
            else:
                # state['sum'] is the accumulated squared gradient:
                # sum = sum + grad * grad
                state['sum'].addcmul_(grad, grad, value=1)
                # std = sqrt(sum) + eps
                std = state['sum'].sqrt().add_(group['eps'])
                # apply the update:
                # p(t) = p(t-1) - clr * grad / std
                p.addcdiv_(grad, std, value=-clr)

    return loss
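For context, here is a minimal sketch of driving this step() through the public torch.optim.Adagrad API; the model, data, and hyperparameter values below are arbitrary placeholders:

import torch

# toy setup: one linear layer and random data (placeholder values)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1,
                                lr_decay=0.01, weight_decay=0.0, eps=1e-10)

x = torch.randn(8, 4)
target = torch.randn(8, 2)

for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()  # invokes the step() shown above for each param group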

The PyTorch source differs slightly from the figure in the Deep Learning book: the source additionally decays the learning rate via lr_decay as the step count grows.
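To see that decay concretely, a quick sketch of the effective learning rate clr over the first few steps, using the formula from the source above (lr and lr_decay values are arbitrary):

lr, lr_decay = 0.1, 0.01
for step in range(1, 5):
    clr = lr / (1 + (step - 1) * lr_decay)
    print(step, round(clr, 5))
# step 1: 0.1, step 2: 0.09901, step 3: 0.09804, step 4: 0.09709

With lr_decay=0 (the default), clr stays equal to lr and the behavior matches the book's version exactly.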
