探索混合专家(MoE)模型预训练:开源项目实操

作者:Mantaverse@知乎,哥伦比亚大学

MOE模型是什么

相比于传统的Dense模型,MoE(Mixture of Experts)模型在结构上进行了优化,特别是在线性投影层方面。MoE模型将单一的全连接层替换成多个专家层(例如,Mixtral使用了8个专家层)。在Switch Transformer的论文中,我们了解到,每次进行token预测时,模型会从这8个专家层中选出两个用于线性推理。这种方法旨在提高模型的性能和效率。

e5cb135a90693241456090fb5a6dc1f9.jpeg
Switch Transformer

这种设计有什么优势呢? 首先,它通过引入专家层,能够在每次计算中仅激活部分网络,从而减少计算资源的消耗。具体来说,MoE模型在推理阶段仅需计算两个被选中的专家层,而不是激活所有的专家层或整个网络。这使得计算量显著减少,从而降低了推理成本。

此外,虽然MoE模型整体参数量较大,但由于每次推理只使用部分专家层,实际参与计算的参数量远小于整个模型的参数总量。这意味着,即使MoE模型的参数量比传统Dense模型大得多,其实际计算成本却要低得多。例如,在Mixtral 8X7B中,每个专家层有7亿个参数,但每次推理只使用两个专家层,因此实际计算的参数量是14亿个,而不是所有8个专家层的总参数量。

更重要的是,MoE模型通过选择最适合当前任务的专家层,可以在不同的任务或数据输入下表现出更强的适应性和泛化能力。由于每个专家层可以专注于处理特定类型的数据或任务,整体模型的性能能够显著提升。结果是,MoE模型在推理阶段不仅具有更低的计算成本,还能在许多情况下比参数量更大的Dense模型表现更好。

具体到Mixtral 8X7B的实现上,Mixtral 8X7B使用了8个独立的专家层,每个专家层都有7亿个参数。在模型进行推理时,会动态选择两个最合适的专家层进行计算,从而实现高效的推理过程。这种设计不仅提高了模型的计算效率,还增强了其在处理大规模数据和复杂任务时的能力。

实现Moe 模型

基于我们在上一篇文章(从零预训练LLAMA3的完整指南:一个文件,探索Scaling Law)中实现的Llama3模型,只需要对模型做些许修改,便能够将Dense模型转化为MoE模型。

实现结果开源在:

https://github.com/hengjiUSTC/learn-llm/blob/main/pretrain/train_mixtral.py

实现步骤拆解

在Dense模型中,MLP的实现如下:

c57cb301600a5ea704320f470c4bd279.jpeg

我们希望把这个组件替换成8个专家层。首先,我们复用MLP的定义,因为本质上,专家层的结构和MLP是一样的:

bd273bb1b02c4229dd589491070a4874.jpeg

在此基础上,我们定义了一个专门处理8个专家层选择逻辑的gate层。在初始化中,我们定义了多个专家层和一个gate。gate的实现非常简单。在每次预测任务中,针对当前的hidden state,通过gate映射到一个(1,x)维度的结果中,结果中较大的值对应我们要选择的专家编号。

476b74bdf8e6af93395dc34fe370c6a1.jpeg

向前传播过程定义如下:

96ddf52a070fbcc2ffb1fb195869d1c7.jpeg

我们接下来拆解向前传播的每一步的效果

1. 初始化和形状调整

获取输入 hidden_states 的形状信息,并在训练时添加抖动噪声以增强模型的鲁棒性:

batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
    hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_dim)

2. 计算路由器的logits

通过一个门控网络(gate)计算出每个输入token对应的专家权重:

router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

3. 初始化和创建专家掩码

初始化 final_hidden_states 用于存储计算结果,并创建专家掩码以便索引:

final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_exp

4. 循环计算专家层输出

遍历所有专家层,计算每个专家层的输出,并累加到最终隐藏状态:

for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

5. 恢复形状并返回结果

将 final_hidden_states 重新调整为原始输入的形状,并返回最终隐藏状态和路由logits:

final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

实现完成最重要的MoE层后,其他的网络结构和llam3基本没有区别,我们就不重新赘述了,感兴趣的朋友可以看之前的文章:从零预训练LLAMA3的完整指南:一个文件,探索Scaling Law

预训练效果对比

本次试验的结果开源在:

https://github.com/hengjiUSTC/learn-llm/blob/main/pretrain/mixtral_result.ipynb

我们对比了在不同setup下MoE模型的效果,对比了三个模型在同样的训练配置下的hellaswag评估效果:

  • • Dense模型,12层,12个attention head,4个kv head,推理时的激活参数180M

  • • Moe模型A,12层,12个attention head,4个kv head,8个expert 选择两个激活,推理时的激活参数266M

  • • Moe模型B,10层,8个attention head,4个kv head,8个expert 选择两个激活,推理时的激活参数150M

dd6660d854a39aed54ef559d4b35f1c2.jpeg

从结果可以看出,在相同的模型结构下,MoE模型A的效果远远超过了Dense模型。此外,MoE模型B在激活参数更少的情况下,依然能够在训练效果上超过Dense模型。这突显了专家层的优势:MoE模型能够以更少的激活参数实现更好的模型效果。

Deepseek MoE

在MoE模型的最新进展中,DeepSeek MoE模型进一步优化了专家的分配和选择机制。相比于Mixtral的8个专家层,DeepSeek通过引入更多且更小的专家,将8个专家层扩展到了64个专家层,并且每个专家层的大小缩小了4倍。

c9cccd4f5e699bb0e8329485d5074722.jpeg

此外,DeepSeek定义了一个名为share expert的特殊专家层,这个专家层会参与任何输入的计算,从而保证了计算的一致性和效果。在实验中,DeepSeek MoE模型展现出了极其优异的性能。即使使用接近Llama2 7B模型四分之一的激活参数量,通过2B的激活参数,DeepSeek MoE模型依然达到了与Dense模型相当的效果。

faa5e813d779cf5a2a47c7c79196b2d0.jpeg

结语

通过对比不同配置下的Dense模型和MoE模型,我们清楚地看到了MoE架构在提升性能和优化计算资源方面的巨大潜力。MoE模型不仅在相同参数量下表现优异,更在激活参数减少的情况下依然保持了高效的训练效果。

特别是DeepSeek MoE模型,通过增加专家层数量和引入share expert的创新机制,大幅提升了计算效率和模型效果。DeepSeek MoE在使用更少激活参数的前提下,依然能够达到与大型Dense模型相当的性能,展示了其在处理复杂任务中的独特优势。

本次试验的训练代码和结果全部开源,欢迎大家关注我的github repo,里面会发布更多模型相关的实操代码:https://github.com/hengjiUSTC/learn-llm/tree/main

——The  End——

b00324daa1cfd27cc6993c6cb3173278.gif

分享

收藏

点赞

在看

cd1b070727469d90fc9871414fc7d8d7.gif

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值