PyTorch-LAMB Open Source Project Tutorial

pytorch-lamb: an implementation of the LAMB optimizer (https://arxiv.org/abs/1904.00962). Project address: https://gitcode.com/gh_mirrors/py/pytorch-lamb

1. Project Directory Structure

The directory structure of the PyTorch-LAMB project is as follows:

pytorch-lamb/
├── pytorch_lamb/
│   ├── __init__.py
│   ├── lamb.py
├── tests/
│   ├── __init__.py
│   ├── test_lamb.py
├── README.md
├── setup.py

Directory overview

  • pytorch_lamb/: the project's main source package.
    • __init__.py: package initialization file.
    • lamb.py: the main module implementing the LAMB optimizer.
  • tests/: the project's test code.
    • __init__.py: test package initialization file.
    • test_lamb.py: tests for the LAMB optimizer.
  • README.md: project documentation.
  • setup.py: installation script for the project (see the sketch after this list).
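
Per setup.py, the package can typically be installed with pip install -e . from the repository root. The import sketch below assumes that pytorch_lamb/__init__.py re-exports the Lamb class from lamb.py; if it does not, import from pytorch_lamb.lamb instead:

import torch
from pytorch_lamb import Lamb  # assumed re-export; otherwise: from pytorch_lamb.lamb import Lamb

# A throwaway parameter, just to confirm the optimizer constructs correctly
param = torch.nn.Parameter(torch.randn(4, 4))
optimizer = Lamb([param], lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
print(optimizer)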

2. Project Entry File

The project's entry file is pytorch_lamb/lamb.py, which implements the main logic of the LAMB optimizer. Its main contents are as follows:

import math
import torch
from torch.optim.optimizer import Optimizer

class Lamb(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, clamp_value=float('inf')):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, clamp_value=clamp_value)
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Adam bias correction, folded into the step size:
                # step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Norm of the current weights, clamped to clamp_value
                weight_norm = p.data.norm(2).clamp(0, group['clamp_value'])

                # Adam-style update direction, plus (decoupled) weight decay
                adam_step = exp_avg / denom
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                # Layer-wise trust ratio: ||w|| / ||update||, falling back to 1
                # when either norm is zero
                adam_norm = adam_step.norm(2)
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1.0
                else:
                    trust_ratio = weight_norm / adam_norm

                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss
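
For context, the effective per-layer step is lr × trust_ratio × adam_step, so layers whose weights are large relative to their update receive proportionally larger steps. Below is a minimal training-loop sketch; the toy model, data, and hyperparameters are placeholders for illustration and are not taken from the project:

import torch
import torch.nn.functional as F
from pytorch_lamb import Lamb

# Toy model and data, for illustration only
model = torch.nn.Linear(10, 1)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(32, 10)
y = torch.randn(32, 1)

for _ in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    print(loss.item())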

