迎接DeepSeek开源周[Kimi先开为敬]发布开源最新Muon优化器可替代 AdamW计算效率直接翻倍

Muon优化器在小规模语言模型训练中表现出色,但在大规模模型训练中的可扩展性尚未得到证实。月之暗面通过系统分析和改进,成功将 Muon 应用于 3B/16B 参数的 MoE 模型训练,累计训练 5.7 万亿 token。结果表明,Muon 可以替代 AdamW 成为大规模 LLM 训练的标准优化器,在训练效率和模型性能方面具有显著优势。通过开源实现、Moonlight 模型和训练中间检查点,本文旨在推动可扩展优化技术研究,加速 LLM 训练方法发展。

🔗 资源链接

🛠 核心改进

1. 权重衰减机制

通过在 Muon 中引入标准 AdamW 式权重衰减,有效解决了模型参数和层输出 RMS 过大的问题。

参数更新公式

使用如下数学公式描述 Muon 优化器的权重更新规则:

W t = W t − 1 − η t ( O t + λ W t − 1 ) W_{t} = W_{t-1} - \eta_{t} \left( O_{t} + \lambda W_{t-1} \right) Wt=Wt1ηt(Ot+λWt1)

公式解析

  • W t W_t Wt:第 t 步的权重矩阵
  • η t \eta_t ηt:动态学习率
  • O t O_t Ot:当前梯度估计量
  • λ \lambda λ:权重衰减系数
  • ( ⋅ ) \left( \cdot \right) ():自适应缩放运算符

该公式实现了:

  1. 权重衰减项 λ W t − 1 \lambda W_{t-1} λWt1 的显式分离
  2. 梯度方向与正则化项的联合优化
  3. 通过动态学习率 η t \eta_t ηt 实现训练稳定性控制在这里插入图片描述

2. 参数更新尺度调整:一致的更新 RMS

在这里插入图片描述

改进参数更新规则,确保不同形状矩阵间的更新 RMS 一致性,显著提升训练稳定性。,提出基于矩阵维度特性的调整策略:

缩放规则

对每个矩阵按其最大维度进行归一化处理:
max ⁡ ( A , B ) \sqrt{\max(A,B)} max(A,B)

更新公式

调整后的参数更新规则实现为:
W t = W t − 1 − η t ( 0.2 ⋅ O t ⋅ max ⁡ ( A , B ) + λ W t − 1 ) W_{t} = W_{t-1} - \eta_{t} \left( 0.2 \cdot O_{t} \cdot \sqrt{\max(A,B)} + \lambda W_{t-1} \right) Wt=Wt1ηt(0.2Otmax(A,B) +λWt1)

关键改进

  1. 维度感知缩放:通过 max ⁡ ( A , B ) \sqrt{\max(A,B)} max(A,B) 补偿不同形状矩阵的梯度尺度差异
  2. 经验系数:Muon 的更新均方根误差控制在 0.2

3. 分布式实现优化

开发基于 ZeRO-1 风格的分布式版本,实现:

  • 内存占用优化
  • 通信效率提升(梯度同步频率降低 50%)

🧪 实验设计

模型架构

采用类 Deepseek-V3-Small 架构,并针对 Moonlight 模型需求进行微调。

数据集

使用 Kimi 团队提供的 5.7 万亿 token 数据集进行预训练。

训练流程

分阶段优化策略:

  1. 渐进式提升学习率
  2. 动态调整批量大小
  3. 多阶段数据质量优化

📊 实验结果

一致性更新 RMS

更新效果对比
调整学习率方法(Adjusted LR)表现最优,显著优于:

  • 基线方法(Baseline)
  • 仅保持与 AdamW 一致 RMS 的方法(Update Norm)

扩展性验证

Muon 在计算最优设置下仅需 52% 训练 FLOPs 即可达到 AdamW 同等性能。

预训练性能

Moonlight 模型在 1.2T token 时的表现显著优于 AdamW 训练的 Moonlight-A 模型。

微调表现

  • 优势场景:Muon 预训练+微调模型全面优于 AdamW 预训练+微调模型
  • 局限场景:当微调阶段切换优化器时,Muon 优势减弱

关键问题及回答

问题 1:Muon 优化器在大规模模型训练中引入权重衰减的具体作用是什么?

权重衰减在 Muon 优化器中的作用主要是防止模型权重过大,从而避免模型在训练过程中出现梯度爆炸或梯度消失的问题。具体来说,权重衰减通过在更新规则中加入一个正则化项来限制权重的增长。论文中的权重衰减公式如下:

[W_{t}=W_{t - 1}-\eta_{t}\left(O_{t}+\lambda W_{t - 1}\right)]

其中,(\lambda)是权重衰减比率。通过引入权重衰减,Muon 能够在训练过程中更好地控制权重的增长速度,确保模型在训练后期不会出现过大的权重值,从而提高模型的稳定性和泛化能力。实验结果表明,加入权重衰减后,Muon 在大规模模型训练中的性能显著提升。

问题 2:分布式 Muon 实现是如何提高内存效率和减少通信开销的?

分布式 Muon 实现通过以下方式提高内存效率和减少通信开销:

  1. Reduce-Scatter 操作:在数据并行组上进行梯度聚合,减少了全局通信的需求。
  2. 局部分区动量:使用局部分区动量进行动量应用,避免了全局动量的传输。
  3. Newton-Schulz 迭代:在本地计算 Newton-Schulz 迭代,只对需要的部分进行通信。
  4. DP Gather 操作:将局部分区更新矩阵聚合成全矩阵,这一步骤只在必要时进行一次。
  5. 减少冗余计算:在计算全更新矩阵后,丢弃不需要的部分,只保留局部更新矩阵进行下一步计算。

通过这些优化措施,分布式 Muon 在保持算法数学性质的同时,显著减少了内存使用和通信开销,提高了大规模模型训练的效率。

问题 3:Moonlight 模型在监督微调阶段的表现如何,与仅使用 AdamW 预训练和微调的模型相比有何优势?

Moonlight 模型在监督微调阶段表现出色,具体优势如下:

  1. 更高的性能:Muon 预训练和微调的模型在多个基准测试中均表现出比仅使用 AdamW 预训练和微调的模型更高的性能。例如,在 MMLU 和 GSM8k 基准上,Moonlight 模型分别取得了 70.0 和 77.4 的分数,而 AdamW 微调的模型分别为 66.7 和 70.7。
  2. 一致性优化:Muon 在整个训练过程中保持了更好的优化稳定性,避免了在微调阶段出现的梯度爆炸或梯度消失现象。
  3. 泛化能力:Muon 预训练和微调的模型在未见数据上的表现也更好,显示出更强的泛化能力。

注意

当微调阶段使用与预训练阶段不同的优化器时,Muon 并未表现出显著优势。这表明,为了充分发挥 Muon 的优势,建议在预训练和微调阶段都使用相同的优化器。

代码

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon内部运行标准的SGD动量优化,然后执行一个正交化后处理步骤,
    在该步骤中,每个二维参数的更新将被替换为最近的正交矩阵。
    为了高效地对每个更新进行正交化,我们使用牛顿 - 舒尔茨迭代,
    其优点是可以在GPU上以bfloat16格式稳定运行。

    一些警告:
    - 我们认为这个优化器不太可能在小批量训练中表现良好。
    - 我们认为它可能不太适合微调预训练模型,但我们还没有进行测试。

    参数:
        muon_params: 要由Muon优化的参数。
        lr: 学习率。更新的谱范数将为`lr`。(0.02是一个不错的默认值)
        momentum: 内部SGD使用的动量。(0.95是一个不错的默认值)
        nesterov: 是否在内部SGD中使用Nesterov风格的动量。(推荐)
        ns_steps: 要运行的牛顿 - 舒尔茨迭代的步数。(6步可能总是足够的)
        adamw_params: 要由AdamW优化的参数。`muon_params`中任何为{0, 1}维的参数
                      或者被检测为嵌入层或lm_head的参数也将由AdamW进行优化。
        adamw_lr: 内部AdamW的学习率。
        adamw_betas: 内部AdamW的betas参数。
        adamw_eps: 内部AdamW的epsilon参数。
        adamw_wd: 内部AdamW的权重衰减参数。
    """

    def __init__(
        self,
        lr=1e-3,
        wd=0.1,
        muon_params=None,
        momentum=0.95,
        nesterov=True,
        ns_steps=5,
        adamw_params=None,
        adamw_betas=(0.95, 0.95),
        adamw_eps=1e-8,
    ):
        # 定义默认参数字典
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )

        # 将muon_params转换为列表
        params = list(muon_params)
        # 如果adamw_params不为None,将其转换为列表,否则设为空列表
        adamw_params = list(adamw_params) if adamw_params is not None else []
        # 将adamw_params的参数添加到params列表中
        params.extend(adamw_params)
        # 调用父类的初始化方法
        super().__init__(params, defaults)
        # 将参数分为使用Muon优化的和不使用Muon优化的两类
        for p in muon_params:
            # 对于muon_params中的每个参数,确保其维度为2
            assert p.ndim == 2, p.ndim
            # 标记该参数使用Muon进行优化
            self.state[p]["use_muon"] = True
        for p in adamw_params:
            # 对于adamw_params中的每个参数,标记其不使用Muon进行优化
            self.state[p]["use_muon"] = False

    def adjust_lr_for_muon(self, lr, param_shape):
        # 获取参数矩阵的前两个维度
        A, B = param_shape[:2]
        # 我们根据参数矩阵的大小调整学习率和权重衰减,如论文中所述
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
        # 计算调整后的学习率
        adjusted_lr = lr * adjusted_ratio
        return adjusted_lr

    def step(self, closure=None):
        """执行单个优化步骤。

        参数:
            closure (Callable, optional): 一个闭包,用于重新评估模型并返回损失。
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            ############################
            #           Muon           #
            ############################

            # 筛选出使用Muon优化的参数
            params = [p for p in group["params"] if self.state[p]["use_muon"]]
            # 获取当前组的学习率
            lr = group["lr"]
            # 获取当前组的权重衰减
            wd = group["wd"]
            # 获取当前组的动量
            momentum = group["momentum"]

            # 以分布式方式生成权重更新
            for p in params:
                # 进行合理性检查
                g = p.grad
                if g is None:
                    continue
                if g.ndim > 2:
                    g = g.view(g.size(0), -1)
                assert g is not None

                # 计算更新
                state = self.state[p]
                if "momentum_buffer" not in state:
                    # 如果动量缓冲区不存在,初始化为与梯度相同形状的零张量
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                # 更新动量缓冲区
                buf.mul_(momentum).add_(g)
                if group["nesterov"]:
                    # 如果使用Nesterov动量,更新梯度
                    g = g.add(buf, alpha=momentum)
                else:
                    g = buf
                # 使用牛顿 - 舒尔茨迭代计算正交化更新
                u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])

                # 调整学习率
                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)

                # 应用权重衰减
                p.data.mul_(1 - lr * wd)

                # 应用更新
                p.data.add_(u, alpha=-adjusted_lr)

            ############################
            #       AdamW backup       #
            ############################

            # 筛选出不使用Muon优化的参数
            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
            lr = group['lr']
            # 获取AdamW的betas参数
            beta1, beta2 = group["adamw_betas"]
            # 获取AdamW的epsilon参数
            eps = group["adamw_eps"]
            # 获取AdamW的权重衰减参数
            weight_decay = group["wd"]

            for p in params:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]
                if "step" not in state:
                    # 如果步数信息不存在,初始化步数和一阶、二阶动量
                    state["step"] = 0
                    state["moment1"] = torch.zeros_like(g)
                    state["moment2"] = torch.zeros_like(g)
                # 更新步数
                state["step"] += 1
                step = state["step"]
                buf1 = state["moment1"]
                buf2 = state["moment2"]
                # 更新一阶动量
                buf1.lerp_(g, 1 - beta1)
                # 更新二阶动量
                buf2.lerp_(g.square(), 1 - beta2)

                # 计算更新
                g = buf1 / (eps + buf2.sqrt())

                # 计算偏差修正
                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step
                scale = bias_correction1 / bias_correction2**0.5
                # 应用权重衰减
                p.data.mul_(1 - lr * weight_decay)
                # 应用更新
                p.data.add_(g, alpha=-lr / scale)

        return loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI仙人掌

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值