手撕MOE混合专家模型:从零推导到DeepSeek稀疏共享专家实战(附PyTorch源码)

一、🌟前言:站在大模型进化的十字路口 🌟

近年来,大语言模型的发展已悄然进入「混合专家」(Mixture of Experts, MOE)的时代。当人们还在惊叹GPT-3的千亿参数时,行业顶尖玩家已通过MOE架构实现了参数效率的指数级突破——用更小的计算代价调度更大的模型容量。而国产大模型新星DeepSeek公布的稀疏共享专家(Sparse Share-Expert)技术,更是将这一领域推向了新的高潮。

但令人困惑的是:
🔍 为什么MOE成了大模型的必备架构?
🔍 专家路由算法如何平衡计算效率与模型性能?
🔍 DeepSeek的共享专家策略相比传统MOE有哪些颠覆性创新?

本文将带您开启一场从数学原理到工业级实践的深度旅程
1️⃣ 手撕MOE核心公式:从条件概率分解出发,推导门控网络的数学本质
2️⃣ PyTorch逐行实现:构建可扩展的专家模块库,复现GShard动态路由策略
3️⃣ DeepSeek魔改解析:揭秘稀疏共享专家的三阶优化(专家复用/梯度隔离/稀疏激活)
4️⃣ 实战效果对比:在8xV100集群上测试传统MOE vs Share-Expert的吞吐量差异

无论您是希望夯实MOE理论基础的算法工程师,还是渴望掌握大模型架构设计精髓的技术极客,这篇文章都将成为您技术武器库中的关键拼图。文末提供完整实验代码与消融实验数据包,期待与您在评论区深度探讨! 💻🚀

二、MoE的基础详解

1.MoE是什么?为什么要提出MoE架构?

MoE是基于Transformer架构的模型,是一种稀疏模型(Sparse Model)(论文原文:GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding)。Transformer,想必大家已经很熟悉了,现在几乎所有的大模型都是基于这个架构的,那么就剩下稀疏了。说到“稀疏”,我们得先知道它的反义词——“稠密”。现在,我们把一个AI大模型(不管多复杂)都想象成一个黑盒子,从一头扔进去一个问题,它内部由海量的参数组成的网络开始计算,从另一头吐出答案。而所谓“稠密”,是指几乎所有的参数都会对所有输入数据进行处理的模型。我们可以想象一个全员都是强迫症的公司,每次来活了,不管这活是写代码,还是设计海报,CEO一声令下,公司里面从程序员到设计师,再到扫地阿姨,所有员工都必须全部参与进来,一起头脑风暴,哪怕很多人的工作都和这半毛钱关系都没有。我们平时使用的很多模型都属于稠密模型,可想而知,这效率得多低啊!由此,稀疏模型就登场了。稀疏模型就好比是现代公司,分为很多部门、小组,比如说运营组、法律咨询组、数据分析组等等,来了一个任务,它不会搞全员的头脑风暴,只会由一个聪明的项目经理去找到底需要哪个或哪几个专家小组来处理,然后,只有被点到名的专家们才会被激活并且参与运算,这就是MoE的核心思想。

虽然MoE仍然有很巨大的参数总量,但是在处理每个具体的任务时,只调用其中的一小部分最合适的专家参数。举个例子,假设我们有两个参数量都是60B的大模型,一个是稠密的,一个是MoE架构的。当我们问“今天天气怎么样”时,那个稠密的模型60B的参数都会参与运算,600亿个,就好比是用加农炮打蚊子吧。那个MoE模型呢,虽然参数总量都是60B,但是内部可能分成了8个专家(打个比方哈),那么每个专家都是7B的参数量,当它分析问题的时候,项目经理觉得这个问题交给第3号和第5号就够了,于是实际参与计算的参数也就15B左右,这会带来什么效果呢?

显而易见,在同样的硬件条件下,MoE模型的效率更高,也就是说MoE更适合那些高吞吐量、高效率的场景。

                             

2.基础版本的MoE是如何工作的?

这是MoE架构对输入的分配机制:

这一部分就是传统Dense Model的Attention Layer,而这部分是将原来传统的FFN替换成的MoE Layer。

我们先看第一个输入,经过Attention  Layer之后,进入一个Router()之后,过一个FFN+softmax,计算出每个expert对应的每个probability(需要注意的是,这里有一个非常重要的超参“”,代表了选择2个expert),所以softmax之后,会进行probability distribution,由于k = 2,所以选择probability最大的那两个expert(这里是expert1和expert3的probability最大),最后将expert1和expert3做一个Aggregation(也就是加和),然后再通过Add + Normalization Layer得到最终的输出。后面的输入也是类似的。从这一点上看,MoE Model就是将Dense Model的Dense FFN Layer换成了Sparse MoE Layer。

【大模型面试】经典考题

问题:主流LLM为什么选用MoE架构?MoE相较Dense的核心优点有哪些?LLM不可能三角?

现在的大语言模型的发展已经证明scaling law是有效的,简单来说,就是从三个方面将模型scaling up ,这三个方面就是模型大小(参数量)、训练数据和计算量。当我们从三个方面scaling up之后模型的性能也会随之提升,所以领域内基本沿用scaling law的思路开发新的模型。但是对于Dense Model这类模型,沿用scaling law时就遇到了瓶颈,这个瓶颈就是LLM不可能三角(impossible triangle),就是对于Dense Model来说,有三个因素互相制约:

如果继续沿用scaling law的思想,那么必然我们会在size层面增加Dense Model的大小,但由于是Dense Model,所以虽然performance上去了,但是cost也会急剧增加(包括训练,也包括推理时的cost),这也就出现了impossible triangle。

而MoE Model做的主要工作就是将model size(parameters)和computational cost进行解耦

我们可以把模型的大小(参数量)做的很大,但是可以控制模型的计算花费,因为MoE并不是将所有参数都激活,而computational cost来源于activate的参数,而不是total parameters,这样就能增大size,使得模型的performance得到很好的提高,但是由于activate parameters的存在,将computational cost控制的很好,不让cost随着size的增大而急剧增大。从不可能三角的角度,MoE就比Dense Model好很多。

另一方面,MoE的优势可能也在于expert的引入。我们认为模型的知识都存储在FFN中,MoE将各个expert都设计成独立的FFN,也就是说每个expert就如同名字一样,是“专家”。能够让模型的知识更加专业化;同时,随着模型的size的增加,我们可以让模型学习更多的知识,让模型更加优秀。

第三方面,MoE比Dense Model要好的地方可能在于MoE中Sparse的架构。例如,常见的CNN这个模型,通过卷积层去寻找数据里的Sparse partern。当然,在早期的Machine Learning的时代,也有很多时候,人们故意去引入Sparse 的性质,其本质就是去寻找数据中的Sparse Partern,也就类似于矩阵论中的SVD

总的来说,MoE就类似于CNN的卷积层的设计,人为地加入了一些prior,这个prior就是通过人类的先验知识,或者是信息里面存在一些Sparse 的Partern,从而设计Sparse的架构去捕捉数据里的Low dim Partern。这样做的好处在于,会使我们的learning efficiency提高,这就是为什么CNN在图像上的效果要好于MLP,尤其是数据量比较小的情况下,这就是因为CNN引入了enductive bias,而MoE也是认为引入Sparse结构,使模型的学习效率更高了。所以同样的输入量的情况下,MoE要好于Dense Model。这只是我个人的理解,大家可以拿去批判式的思考哈。

我们一直在说MoE的expert不免给大家带来一些错觉,可以看上图,横轴是expert ID,纵坐标是domain specialization的百分比,这里也有不同的Domain:GitHub、arXiv、Wikipedia、Books、C4。这里要强调的一点是,其实我们所说的MoE也是有不同的Layer的,比如展示出来的Layer0、Layer7和Layer15,MoE中的每个Layer都有对应的expert,所以整个模型是由很多这样的MoE的Layer叠加起来的,并不存在某一个expert去学习、激活,而是一个Layer层甚至横跨不同的Layer中的不同expert激活、学习。

3.基础版本的MoE代码从零实现详解

(1)版本一:基础MoE架构PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicExpert(nn.Module):
    def __init__(self,feature_in,feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in,feature_out)

    def forward(self,x):
        return self.fc(x)
    

class BasicMoE(nn.Module):
    def __init__(self,feature_in,feature_out,num_experts):
        super().__init__()
        self.gate = nn.Linear(feature_in,num_experts)
        # output shape is (batch_size,num_experts)
        self.experts = nn.ModuleList(
            BasicExpert(
                feature_in,feature_out
            ) for _ in range(num_experts)
        )

    def forward(self,x):
        # x shape is (batch_size,feature_in)
        # feature_in 也可以叫做hidden_size或者hidden_dim
        expert_weight = self.gate(x) # expert_weights shape is (batch_szie,num_experts)
        expert_out_list = [
            expert(x).unsqueeze(1) for expert in self.experts
        ]
        # 每一个expert输出一个(batch_size,feature_out)

        expert_outputs = [
            expert_out
            for expert_out in expert_out_list
        ]

        # concat 起来 (batch_size, num_experts, feature_out)
        expert_output = torch.cat(expert_out_list, dim=1)

        # print(expert_output.size())
        # 注:可以选择对expert_output做softmax也可以选择不做
        expert_weight = F.softmax(expert_weight,dim = 1)

        expert_weight = expert_weight.unsqueeze(1) # (batch_size, 1, num_experts)
        
        output = expert_weight @ expert_output  # output shape is (batch_size,1,feature_out)
        return output.squeeze(1)  # 最终希望得到的维数为 (batch_size,feature_out)
    
def test_basic_moe():
    x = torch.rand(4,512)
    basic_moe = BasicMoE(512,128,4)
    output = basic_moe(x)
    print(output.shape)

test_basic_moe()
    

其中,每个MoE都对应着一个expert,而每个expert一般是FeadForward Network,FFN。为了简化代码并让大家能够清晰了解到MoE的整体架构,版本一中的FFN用一层Linear层代替了。

(2)版本二:把 expert 改成 swishGLU 版本的 FFN 专家(课后作业

我们将expert由Linear替换成swishGLU版本的FFN,作为大家课后练习作业

我们先深入了解一下什么是swishGLU?它版本的FFN如何coding?

1)什么是swishGLU?

自22年ChatGPT问世以来,大模型(Large Language Model, LLM)一般使用 Causal Language Model 的形式,属于 Transformers 中的 Decoder 部分,其中在 Decoder 的 Block 中有一个 FFN(FeadForward) 层,一般认为这部分参数用于存储知识。而标准的 FFN 一般有一个升维度和降维度的过程,一共有两个权重矩阵,用公式表示为

FFN(x) = ReLU(xW_1 + b_1)W_2 + b_2                           (1)

其中 x shape 是 (𝑏,𝑠,ℎ),w1 shape 是 (ℎ,4ℎ),w2 shape 是 (4ℎ,ℎ), w1 是升维(up),w2 是降维(down)

激活函数主要是为了实现神经网络学习输入和输出之间的复杂非线性关系而使用的一个函数。在公式 (1) 中,ReLU 是一个激活函数(Transfromers原版),可以替换成其他的激活函数,比如 BERT 开始用 Gaussian Error Linear Unit,GELU 比较多,随后就成了激活函数的主流选择,但是随着大模型的爆火以及 PaLM 模型的发布,大家开始慢慢使用 swishGLU 作为激活函数,并且作为一个主要的优化点。

SwiGLU(或者swishGLU,以下可能混用) 是 swish 激活函数和 GLU 门控单元的结合体,因此需要分别介绍两者的不同。

其中需要注意的是:在 T5 开始,很多模型(比如 PaLM )在FFN层都不用 bias 了,也就是说 FFN的公式变成了

FFN(x) = ActiveFunction(xW_1)W_2                               (2)

注意公式 2 和公式 1 的区别,一共没有 bias 一个有 bias,但具体得看不同模型的实现,并不能一概而论。

1.1)swish 激活函数

swish 是一个非线性函数(激活函数一般都是如此),具体公式为:

                        Swish = x\sigma (\beta x)                                                  (3)

其中,\beta是一个超参数,当\beta = 1 时,Swish 就变成了 SiLU (Sigmoid Linear Unit) , 大多数框架的默认实现 (如PyTorch、TensorFlow的nn.SiLU())使用的是\beta = 1的固定版本。

因此如果采用 swish 激活函数,FFN 的公式变成了

            FFN(W_1,W_2,x) = Swish(xW_1)W_2                            (4)

共有两个可学习的矩阵,其中 W_1,(ℎ,4ℎ) 是升维矩阵,W_2,(4ℎ,ℎ)是降低维度的矩阵。

1.2)GLU门控单元

GLU,Gated Linear Units,是一种门控结构(有参数,因此相对于普通的激活函数多了一个 gate 矩阵),通过 sigmoid 控制不同维度的激活。公式如下

GLU(W,x,V,b,c) = (Wx + b) \bigotimes sigmoid(Vx + c )              (5)

这里是不是熟悉 LSTM, GRU 的同学一下就理解,其中需要注意的是,b, c 对应的 bias 不是必须的。

对比公式 5 和公式 7,公式 7 中的​ \omega _{up}对应 公式 5 中的 𝑊,而\omega _{down} 对应公式 5 中的 𝑉 矩阵。

1.3)SwiGLU的表达式

而 SwiGLU 就是把门控函数替换成了 swish,并且去除掉了 bias 部分,以及把 FFN 层的一个 Linear 层替换成了 GLU 层,因此一共有三个可训练的参数矩阵, w1, w2, w3。

因此最终的公式表达为:

FFN(W_1,W_2,W_3,x) = W_2 \cdot (W_1x\bigotimes Swish(W_3x))                      (6)

而我们都知道 FFN 是一个升高维度,然后降低维度的过程,因此可以写成,W2 是一个降低维度的参数,W1 是升高维度的过程,而 W3 是一个 Gate 需要用到的参数矩阵。

FFN(\omega _{up},\omega _{down},\omega _{gate}) = \omega _{down}\cdot (\omega _{up}x\bigotimes Swish(\omega _{gate}x))            (7)

通过这个公式整体就非常的清晰理解使用 swiGLU 的 FFN。

而我们都知道在 basic 版本的 FFN,见公式1, 只有 \omega _{up}和 ​ \omega _{down}分别是 (h, 4h) 和(4h, h),因此整体参数是 8h^2

而公式7 中,一共有三个矩阵,如果想要实现总参数 8h^2,那么每一个参数矩阵的大小应该是 \frac{8h^2}{3}​,因此 \omega _{up},\omega _{gate}的shape应该是 (h,\frac{8h}{3})\omega _{down}的 shape 是(\frac{8h}{3},h)

假设输入的 hidden_dim 大小是 hidden_dim,那么中间层(up 后的维度)大小是 mid_dim, 具体计算逻辑如下:

mid_dim = int(8 * hidden_dim / 3)
# multiple_of:make SwiGLU hidden layer size multiple of large power of 2
mid_dim = multiple_of * ((mid_dim + multiple_of - 1) // multiple_of)

# multiple_of 一般设置为 256, LLaMA 和 GPT等模型

注意,在 LLM (大语言模型) 架构中,multiple_of 是一个用于优化计算效率的参数,通常设置为 256 或其他 2 的幂次方数(如 128、512 等),最终让 mid_dim 调整为 multiple_of 的整数倍。这样做有几个原因:

  1. 硬件优化:现代 GPU/TPU 在处理 2 的幂次方大小的张量时效率最高
  2. 内存对齐:确保内存对齐可以提高计算速度
  3. 并行计算效率:某些并行计算操作在处理规整的数字时效率更高
2)带有SwishGLU版本的FFN代码实现
class FFNExpert(nn.Module):
    def __init__(self, hidden_dim, dropout):   # LLM 进化之路, FFN 激活函数从 GELU -> SwiGLU
        super().__init__()  

        # 有一个 magic number 叫做 8/3
        hidden_dim = hidden_dim
        # 这里可以自己去优化成 multiple_of 的倍数
        mid_dim = hidden_dim * 8 // 3

        self.up = nn.Linear(hidden_dim, mid_dim, bias=False)
        self.down = nn.Linear(mid_dim, hidden_dim, bias=False)
        self.gate = nn.Linear(hidden_dim, mid_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = self.dropout(
            self.down(
                # up 之后的 Shape 是(b, s, mid_dim)
                # gate 和 up 之后的Shape都是 (b, s, mid_dim)
                # 两者是 element-wise 相乘
                F.silu(
                    self.gate(x)
                ) * self.up(x)
            )
        )
        return out
3)课后作业参考答案
import torch
import torch.nn as nn
import torch.nn.functional as F

class FFNExpert(nn.Module):
    def __init__(self,hidden_dim,dropout):
        super().__init__()

        # 8 / 3
        hidden_dim = hidden_dim
        # 这里可以自己去优化成multiple_of 的倍数
        mid_dim = hidden_dim * 8 // 3

        self.up = nn.Linear(hidden_dim,mid_dim,bias = False)
        self.down = nn.Linear(mid_dim,hidden_dim,bias = False)
        self.gate = nn.Linear(hidden_dim,mid_dim,bias = False)

        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        out = self.dropout(
            self.down(
                # up 之后的shape is (batch_size,length,mid_dim)
                # gate 和 up 之后的shape is (batch_size,length,mid_dim)
                # 两者都是element-wise相乘
                F.silu(
                    self.gate(x)
                ) * self.up(x)
            )
        )

        return out
    

class BasicMoE(nn.Module):
    def __init__(self,feature_in,feature_out,num_experts):
        super().__init__()
        self.gate = nn.Linear(feature_in,num_experts)
        # output shape is (batch_size,num_experts)
        self.experts = nn.ModuleList(
            FFNExpert(
                feature_in, 0.1  # hidden_dim, dropout_rate
            ) for _ in range(num_experts)
        )

    def forward(self,x):
        # x shape is (batch_size,feature_in)
        # feature_in 也可以叫做hidden_size或者hidden_dim
        expert_weight = self.gate(x) # expert_weights shape is (batch_szie,num_experts)
        expert_out_list = [
            expert(x).unsqueeze(1) for expert in self.experts
        ]
        # 每一个expert输出一个(batch_size,feature_out)

        expert_outputs = [
            expert_out
            for expert_out in expert_out_list
        ]

        # concat 起来 (batch_size, num_experts, feature_out)
        expert_output = torch.cat(expert_out_list, dim=1)

        # print(expert_output.size())
        # 注:可以选择对expert_output做softmax也可以选择不做
        expert_weight = F.softmax(expert_weight,dim = 1)

        expert_weight = expert_weight.unsqueeze(1) # (batch_size, 1, num_experts)
        
        output = expert_weight @ expert_output  # output shape is (batch_size,1,feature_out)
        return output.squeeze(1)  # 最终希望得到的维数为 (batch_size,feature_out)
    
def test_basic_moe():
    x = torch.rand(4,512)
    basic_moe = BasicMoE(512,128,4)
    output = basic_moe(x)
    print(output.shape)

test_basic_moe()

三、SparseMoE,MoE LLM,了解现代MoE大模型怎么训练

1.Sparse MoE

我们利用Switch Transformers这篇文章作为介绍,这篇文章(Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity)是

William Fedus、Barret Zoph和Noam Shazeer于2022年发表在《Journal of Machine Learning Research》的一篇重要论文,提出了一种高效的稀疏激活模型——Switch Transformer,旨在以较低的计算成本实现大规模语言模型的训练。

和 Basic 区别是,MOE 选择 topK 个专家,然后对这 topK 个专家的输出进行加权求和,并且把输入样本变成了大模型中真实的输入 Shape,(batch, seq_len, hidden_dim)

2.Sparse MoE 的PyTorch代码实现及详解

(1)整体流程概述

  • 输入处理:将序列输入(batch_size,seq_len,hidden_dim)拆分成单个token(batch_size * seq_len,hidden_dim),以便路由层处理。
  • 路由选择:路由层(MoERouter)为每个token计算对所有专家的权重,选出Top-K专家及对应的权重。
  • 专家计算:遍历每个专家,收集该专家负责的token,用专家网络处理后,按权重加权融合到最终结果。
  • 结果还原:将处理后的token结果恢复为原始序列形状(batch_size,seq_len,hidden_dim)。

(2)完整代码及详细注释

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicExpert(nn.Module):
    def __init__(self,feature_in,feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in,feature_out)

    def forward(self,x):
        return self.fc(x)

class MOEConfig:
    def __init__(self,hidden_dim,expert_number,top_k,shared_experts_number = 2):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_expert_number = shared_experts_number


class MoERouter(nn.Module):
    def __init__(self,config):
        super().__init__()  
        self.gate = nn.Linear(config.hidden_dim,config.expert_number)
        # 但是后面只会选择top_k个专家

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self,x):
        # 假设expert_number = 8 , top_k = 2
        router_logits = self.gate(x)  # (batch_size * seq_len,expert_number)

        # 计算每一个专家的概率
        router_probs = F.softmax(router_logits,dim = 1, dtype = torch.float)

        # 计算top_k专家的输出
        # 注意 : top_k 是可以反向传播的
        router_weights , selected_experts_indices = torch.topk(
            router_probs,
            self.top_k,
            dim = -1
        )
        # 输出的两个元素router_weights , selected_experts_indices的维度都是 (batch_size * seq_len,top_k)

        # 对router_weights重新做归一化
        router_weights = router_weights / router_weights.sum(
            dim = -1,keepdim=True
        )
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number
        )
        ### 很重要 : 输出的维度是 (batch_size * seq_len,top_k,expert_number)

        expert_mask = expert_mask.permute(2,1,0)
        # 输出维度变成了 (expert_number,top_k,batch_size * seq_len)

        return router_logits,router_weights,selected_experts_indices,expert_mask
        # router_logits shape is (batch_size * seq_len,expert_number)
        # router_weights shape is (batch_size * seq_len,top_k)
        # selected_experts_indices shape is (batch_size * seq_len,top_k)
        # expert_mask shape is (expert_number,top_k,batch_size * seq_len)


class SparseMoE(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

    # 初始化专家
        self.experts = nn.ModuleList(
            BasicExpert(
                config.hidden_dim,
                config.hidden_dim
            ) for _ in range(config.expert_number)
        )
        self.router = MoERouter(config)

    def forward(self,x):
        # x shape is (batch_size,seq_len,hidden_dim)
        batch_size,seq_len,hidden_dim = x.size()

        # 因为是对token维度做计算,所以需要将x reshape 成 (batch_size * seq_len,hidden_dim)
        hidden_states = x.view(-1,hidden_dim)

        # 做相关的专家的计算
        router_logits,router_weights,selected_experts_indices,expert_masks = self.router(
            hidden_states
        )

        # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
        # 最终的维度肯定是 (batch_size * seq_len,hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * seq_len,hidden_dim),
             dtype = hidden_states.dtype,
             device = hidden_states.device
        )

        # 遍历每一个专家,把选中的这个专家的token的hidden_states加到final_hidden_states中
        # 例如:expert-0可能有100个token选中了:
        # token 的总数是 batch_size * seq_len
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
            current_expert_mask = expert_masks[expert_idx]
            # 因为已经将expert_number取出来了,所以current_expert_mask就是一个二维的值了
            # current_expert_mask shape is (top_k,batch_size * seq_len)

            router_weights_idx , top_x = torch.where(current_expert_mask)
            """
            torch.where()的用法:
            where 返回矩阵中为true的行、列位置,这里分别代表top_k中第几个和tokens中的第几个
            """
            # router_weights_idx 是 0 or 1
            # 假设top_k是2,那么idx是代表在top_k里面的第几个(one-hot编码中的0 or 1)
            # router_weights_idx 表示这个token是作为当前专家的top1还是top2

            # top_x 是 token 在 batch_size * seq_len 中的位置索引
            # 例如对于 batch_size = 2,seq_len = 4 的输入:
            # top_x 的值范围是 0 - 7 , 表示在展平后的8个 token 中的位置
            # router_weights_idx 和 top_x 都是一维的值
            # router_weights_idx 是用来选择weight的
            # top_x 是用来选 hidden_states的

            # hidden_states shape is (batch_size * seq_len,hidden_dim)
            current_state = hidden_states.unsqueeze(0)[:,top_x,:].reshape(-1,hidden_dim)
            # hidden_states在做完unsqueeze(0)之后的维度为 (1,batch_size * seq_len,hidden_dim)
            # [:,top_x,:] 后的维度为 (1,top_x,hidden_dim)
            # 最终的current_state shape is (selected_token_number,hidden_dim)

            current_state = expert_layer(current_state)

            # router_weights shape is (batch_size * seq_len,top_k)
            # top_x 对应的是token的维度
            current_token_router_weight = router_weights[top_x,router_weights_idx]
            # 最终的current_token_router_weight 的维度是 (selected_token_number)
            current_token_router_weight = current_token_router_weight.unsqueeze(-1)
            # 现在最终的current_token_router_weight 的维度就变成了 (selected_token_number,1)

            current_hidden_states = current_state * current_token_router_weight
            # current_state shape (selected_token_number,hidden_dim)
            # current_token_router_weight shape (selected_token_number,1)
            ### 注意 : 这里有广播 ###
            # 最终的输出current_hidden_states shape is (selected_token_number,hidden_dim)

            final_hidden_states.index_add_(
                0,
                top_x,
                current_hidden_states.to(hidden_states.dtype)
            )
        # 把final_hidden_states还原到原来的shape
        final_hidden_states = final_hidden_states.reshape(batch_size,seq_len,hidden_dim)

        return final_hidden_states,router_logits
    # shape is (batch_size * seq_len,expert_number)

def test_sparse_moe():
    x = torch.rand(2,4,16)
    config = MOEConfig(16,2,2)
    token_level_moe = SparseMoE(config)
    out = token_level_moe(x)
    print(out[0].shape,out[1].shape)

test_sparse_moe()

(3)关键变量维度与逻辑解析

以下结合测试用例config = MOEConfig(16,2,2)(hidden_dim = 16 , expert_number = 2 , top_k = 2)和输入x = (2,4,16),逐一解析关键变量的维度及变换原因:

1)输入预处理:从序列到Token
# 因为是对token维度做计算,所以需要将x reshape 成 (batch_size * seq_len,hidden_dim)
hidden_states = x.view(-1,hidden_dim)
  • 目的:MoE按​​token级别​​分配专家,需将序列拆分为单个token(去除seq_len维度)。
  • 维度变化:(batch_size, seq_len, hidden_dim)→ (batch_size*seq_len, hidden_dim)(2,4,16)(8,16))。
2)路由层输出:选Top-K专家及权重

路由层MoERouter的输出有4个变量,核心是​​为每个token标记Top-K专家及对应权重​​:

a.router_logits:原始专家分数
# 假设expert_number = 8 , top_k = 2
router_logits = self.gate(x)  # (batch_size * seq_len,expert_number)
  • 维度​​:(token_number, expert_number)→ (8,2)

  • ​意义​​:每个token对2个专家的原始未归一化分数(gate层是Linear(16,2))。

注意:token_number(即展平后的token总数)是整个流程的“核心锚点”——所有与“token级操作”相关的变量都直接或间接依赖于它。

token_number是​​输入序列中所有token的总数​​,由输入xbatch_sizeseq_len相乘得到:

token_{number} = batch_{size} \times seq_{len}

在测试用例中,x.shape=(2,4,16)→ token_num=2×4=8

它的物理意义是:​​将序列拆分为独立的token后,总共有多少个需要处理的“单元”​​

b.router_weights:Top-K专家的归一化权重
# 假设expert_number = 8 , top_k = 2
router_logits = self.gate(x)  # (batch_size * seq_len,expert_number)

# 计算每一个专家的概率
router_probs = F.softmax(router_logits,dim = 1, dtype = torch.float)

# 计算top_k专家的输出
# 注意 : top_k 是可以反向传播的
router_weights , selected_experts_indices = torch.topk(
    router_probs,
    self.top_k,
    dim = -1
)
# 输出的两个元素router_weights , selected_experts_indices的维度都是 (batch_size * seq_len,top_k)

# 对router_weights重新做归一化
router_weights = router_weights / router_weights.sum(
    dim = -1,keepdim=True
)

  • 维度​​:均为(token_number, top_k)→ (8,2)

  • 意义​​:

    • router_weights:每个token选的Top-2专家的​​归一化权重​​(如某token的权重是[0.7, 0.3],表示70%依赖专家0,30%依赖专家1)。

    • selected_experts_indices:每个token选的Top-2专家的​​索引​​(如[0,1]表示选专家0和1)。

c.expert_mask:专家-Token映射的三维掩码
expert_mask = F.one_hot(
    selected_experts_indices,
    num_classes=self.expert_number
)
### 很重要 : 输出的维度是 (batch_size * seq_len,top_k,expert_number)

expert_mask = expert_mask.permute(2,1,0)
# 输出维度变成了 (expert_number,top_k,batch_size * seq_len)
  • ​第一次变换​​:one_hotselected_experts_indices(8,2))转换为(8,2,2)——每个token的Top-2专家索引变成​​one-hot向量​​(如token0选专家0和1,则对应\begin{bmatrix} 1 & 0\\ 0& 1 \end{bmatrix})。

  • ​第二次变换​​:permute(2,1,0)调整维度顺序为(expert_number, top_k, token_number)→ (2,2,8)

  • ​意义​​:快速查询​​每个专家负责的Token​​(如专家0的信息是expert_mask[0]=(2,8),表示专家0的Top-2专家位置对应8个token的掩码)。

说明:

1.permute()操作:permute(*dims)是 PyTorch 中 Tensor 的一个方法,它的作用是​​重新排列张量的维度​​,它不会改变张量中的数据,只是改变了我们“看待”这些数据的方式,即数据的视图(view)。参数 dims指定了新维度的顺序。例如,tensor.permute(2, 0, 1)会将原始张量的第2个维度变为新的第0个维度,原始的第0个维度变为新的第1个维度,原始的第1个维度变为新的第2个维度。

2.为什么要这么做?用处是什么?

进行 permute(2, 1, 0)操作的主要目的是​​改变数据的视角,使其更适合后续以专家为中心 (expert-centric) 的计算​​。

3)专家计算:累积加权结果

SparseMoE的核心是遍历每个专家,收集其负责的Token,处理后加权融合到最终的结果:

final_hidden_states = torch.zeros(
            (batch_size * seq_len,hidden_dim),
             dtype = hidden_states.dtype,
             device = hidden_states.device
        )

        # 遍历每一个专家,把选中的这个专家的token的hidden_states加到final_hidden_states中
        # 例如:expert-0可能有100个token选中了:
        # token 的总数是 batch_size * seq_len
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
            current_expert_mask = expert_masks[expert_idx]
            # 因为已经将expert_number取出来了,所以current_expert_mask就是一个二维的值了
            # current_expert_mask shape is (top_k,batch_size * seq_len)

            router_weights_idx , top_x = torch.where(current_expert_mask)
            """
            torch.where()的用法:
            where 返回矩阵中为true的行、列位置,这里分别代表top_k中第几个和tokens中的第几个
            """
            # router_weights_idx 是 0 or 1
            # 假设top_k是2,那么idx是代表在top_k里面的第几个(one-hot编码中的0 or 1)
            # router_weights_idx 表示这个token是作为当前专家的top1还是top2

            # top_x 是 token 在 batch_size * seq_len 中的位置索引
            # 例如对于 batch_size = 2,seq_len = 4 的输入:
            # top_x 的值范围是 0 - 7 , 表示在展平后的8个 token 中的位置
            # router_weights_idx 和 top_x 都是一维的值
            # router_weights_idx 是用来选择weight的
            # top_x 是用来选 hidden_states的

            # hidden_states shape is (batch_size * seq_len,hidden_dim)
            current_state = hidden_states.unsqueeze(0)[:,top_x,:].reshape(-1,hidden_dim)
            # hidden_states在做完unsqueeze(0)之后的维度为 (1,batch_size * seq_len,hidden_dim)
            # [:,top_x,:] 后的维度为 (1,top_x,hidden_dim)
            # 最终的current_state shape is (selected_token_number,hidden_dim)

            current_state = expert_layer(current_state)

            # router_weights shape is (batch_size * seq_len,top_k)
            # top_x 对应的是token的维度
            current_token_router_weight = router_weights[top_x,router_weights_idx]
            # 最终的current_token_router_weight 的维度是 (selected_token_number)
            current_token_router_weight = current_token_router_weight.unsqueeze(-1)
            # 现在最终的current_token_router_weight 的维度就变成了 (selected_token_number,1)

            current_hidden_states = current_state * current_token_router_weight
            # current_state shape (selected_token_number,hidden_dim)
            # current_token_router_weight shape (selected_token_number,1)
            ### 注意 : 这里有广播 ###
            # 最终的输出current_hidden_states shape is (selected_token_number,hidden_dim)

            final_hidden_states.index_add_(
                0,
                top_x,
                current_hidden_states.to(hidden_states.dtype)
            )
a. current_expert_mask:当前专家的Token掩码
current_expert_mask = expert_masks[expert_idx]
  • 维度​​:(top_k, token_number)→ (2,8)

  • ​意义​​:对于当前专家(如专家0),current_expert_mask的每一行对应Top-K中的一个位置(0=Top1,1=Top2),每一列对应一个Token(0~7)——True表示该Token被选入当前专家的Top-K。

b. torch.where():定位选中的Token的权重位置
router_weights_idx, top_x = torch.where(current_expert_mask)
  • 维度​​:均为(selected_token_number,)(如专家0选了2个Token,则为(2,))。

  • ​意义​​:
    • router_weights_idx:选中的Token在当前专家Top-K中的​​位置​​(0=Top1,1=Top2)。

    • top_x:选中的Token在token_num中的​​索引​​(如[0,1]表示选Token0和1)

c. current_state:当前专家处理的Token Hidden States
current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, 16)  # (2,16)
  • 变换步骤​​:

    1. hidden_states.unsqueeze(0):将(8,16)变为(1,8,16)(增加batch_size维度,方便切片)。

    2. [:, top_x, :]:按top_x(如[0,1])选Top-K个Token,得到(1,2,16)

    3. reshape(-1,16):恢复为(2,16)——当前专家负责的Token的Hidden_States。

  • ​目的​​:将选中的Token集合转换为专家层的输入格式((M, hidden),M是选中的Token数)。

d. current_token_router_weight:Token对应的专家权重
current_token_router_weight = router_weights[top_x, router_weights_idx].unsqueeze(-1)  # (2,1)
  • 维度变化​​:router_weights[top_x, router_weights_idx](2,)unsqueeze(-1)变为(2,1)

  • 意义​​:

    • router_weights[top_x, router_weights_idx]:按top_x(Token索引)和router_weights_idx(Top-K位置)取出对应的权重(如Token0的权重是router_weights[0,0]=0.7,Token1的权重是router_weights[1,1]=0.3)。

    • unsqueeze(-1):将权重从(2,)变为(2,1),以便与current_state(2,16))广播相乘。

e. current_hidden_states:专家处理后的加权结果
current_state = expert_layer(current_state)  # (2,16) → 专家层处理
current_hidden_states = current_state * current_token_router_weight  # (2,16) * (2,1) → (2,16)

意义​​:用专家层处理选中的Token,再乘以对应权重(实现“加权专家输出”)。

f. index_add_ :累加到最终结果
final_hidden_states.index_add_(0, top_x, current_hidden_states)
  • 作用​​:将current_hidden_states(M,16))按top_x(Token索引)加到final_hidden_states(8,16))的对应行。

  • ​意义​​:每个Token可能被多个专家处理(Top-K个),index_add_累积所有专家的加权输出,得到该Token的最终结果。

4)结果还原:从Token到序列
# 把final_hidden_states还原到原来的shape
final_hidden_states = final_hidden_states.reshape(batch_size,seq_len,hidden_dim)

目的​​:将处理后的Token结果恢复为原始序列形状(batch, seq_len, hidden)

(4)关键操作的意义总结

操作示例维度变换意义
view(-1,hidden_dim)(2,4,16)(8,16)拆分为单个Token,便于路由层处理
unsqueeze(0)(8,16)(1,8,16)增加维度,方便用[:, top_x, :]选多个Token
permute(2,1,0)(8,2,2)(2,2,8)调整维度顺序,快速查询每个专家的Token掩码
torch.where(2,8)(2,)+(2,)定位当前专家负责的Token及对应的权重位置
index_add_(8,16)+= (M,16)累积所有专家对Token的加权输出,实现稀疏激活

3.课后作业

把 MOE 应用到上一次的 GPT2.0(GPT-2.0 解剖指南:从原理到实践,手把手实现你的第一个 Transformer 模型 ✨) 中,也就是替换掉原来的 FFN层,注意这里负载均衡 loss 要包含每一层的 MOE 的 router_logits

留个悬念,参考答案下篇博客会专门讲一下是如何替换的,并与原来的GPT-2.0对比有哪些不同之处,各位小伙伴可以期待一下哈!

四、ShareExpert SparseMoE (DeepSeek版本)


备注:这里只是参考DeepSeek MoE(论文原文:DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models)的思想,写的一个共享expert的MoE网络,有一定的简化,但是可以方便理解训练过程。


1.ShareExpert SparseMoE

和《三、SparseMoE,MoE LLM,了解现代MoE大模型怎么训练》的Sparse MoE区别是,这里多了一个shared experts的模型,这个模型是所有token共享的,也就是说,所有token都过这个shared experts模型,然后每个token会用计算的Router权重,来选择top_K个专家,然后和共享的专家的输出一起加权求和。

具体的结构图为:

2.PyTorch代码实现及详解

前面一部分代码就是前面Sparse MoE的代码,这里只是将后续的共享专家加入

import torch
import torch.nn as nn
import torch.nn.functional as F

class MOEConfig:
    def __init__(self,hidden_dim,expert_number,top_k,shared_experts_number = 2):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_expert_number = shared_experts_number

class BasicExpert(nn.Module):
    def __init__(self,feature_in,feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in,feature_out)

    def forward(self,x):
        return self.fc(x)

class MoERouter(nn.Module):
    def __init__(self,config):
        super().__init__()  
        self.gate = nn.Linear(config.hidden_dim,config.expert_number)
        # 但是后面只会选择top_k个专家

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self,x):
        # 假设expert_number = 8 , top_k = 2
        router_logits = self.gate(x)  # (batch_size * seq_len,expert_number)

        # 计算每一个专家的概率
        router_probs = F.softmax(router_logits,dim = 1, dtype = torch.float)

        # 计算top_k专家的输出
        # 注意 : top_k 是可以反向传播的
        router_weights , selected_experts_indices = torch.topk(
            router_probs,
            self.top_k,
            dim = -1
        )
        # 输出的两个元素router_weights , selected_experts_indices的维度都是 (batch_size * seq_len,top_k)

        # 对router_weights重新做归一化
        router_weights = router_weights / router_weights.sum(
            dim = -1,keepdim=True
        )
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number
        )
        ### 很重要 : 输出的维度是 (batch_size * seq_len,top_k,expert_number)

        expert_mask = expert_mask.permute(2,1,0)
        # 输出维度变成了 (expert_number,top_k,batch_size * seq_len)

        return router_logits,router_weights,selected_experts_indices,expert_mask
        # router_logits shape is (batch_size * seq_len,expert_number)
        # router_weights shape is (batch_size * seq_len,top_k)
        # selected_experts_indices shape is (batch_size * seq_len,top_k)
        # expert_mask shape is (expert_number,top_k,batch_size * seq_len)

class SparseMoE(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

    # 初始化专家
        self.experts = nn.ModuleList(
            BasicExpert(
                config.hidden_dim,
                config.hidden_dim
            ) for _ in range(config.expert_number)
        )
        self.router = MoERouter(config)

    def forward(self,x):
        # x shape is (batch_size,seq_len,hidden_dim)
        batch_size,seq_len,hidden_dim = x.size()

        # 因为是对token维度做计算,所以需要将x reshape 成 (batch_size * seq_len,hidden_dim)
        hidden_states = x.view(-1,hidden_dim)

        # 做相关的专家的计算
        router_logits,router_weights,selected_experts_indices,expert_masks = self.router(
            hidden_states
        )

        # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
        # 最终的维度肯定是 (batch_size * seq_len,hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * seq_len,hidden_dim),
             dtype = hidden_states.dtype,
             device = hidden_states.device
        )

        # 遍历每一个专家,把选中的这个专家的token的hidden_states加到final_hidden_states中
        # 例如:expert-0可能有100个token选中了:
        # token 的总数是 batch_size * seq_len
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
            current_expert_mask = expert_masks[expert_idx]
            # 因为已经将expert_number取出来了,所以current_expert_mask就是一个二维的值了
            # current_expert_mask shape is (top_k,batch_size * seq_len)

            router_weights_idx , top_x = torch.where(current_expert_mask)
            """
            torch.where()的用法:
            where 返回矩阵中为true的行、列位置,这里分别代表top_k中第几个和tokens中的第几个
            """
            # router_weights_idx 是 0 or 1
            # 假设top_k是2,那么idx是代表在top_k里面的第几个(one-hot编码中的0 or 1)
            # router_weights_idx 表示这个token是作为当前专家的top1还是top2

            # top_x 是 token 在 batch_size * seq_len 中的位置索引
            # 例如对于 batch_size = 2,seq_len = 4 的输入:
            # top_x 的值范围是 0 - 7 , 表示在展平后的8个 token 中的位置
            # router_weights_idx 和 top_x 都是一维的值
            # router_weights_idx 是用来选择weight的
            # top_x 是用来选 hidden_states的

            # hidden_states shape is (batch_size * seq_len,hidden_dim)
            current_state = hidden_states.unsqueeze(0)[:,top_x,:].reshape(-1,hidden_dim)
            # hidden_states在做完unsqueeze(0)之后的维度为 (1,batch_size * seq_len,hidden_dim)
            # [:,top_x,:] 后的维度为 (1,top_x,hidden_dim)
            # 最终的current_state shape is (selected_token_number,hidden_dim)

            current_state = expert_layer(current_state)

            # router_weights shape is (batch_size * seq_len,top_k)
            # top_x 对应的是token的维度
            current_token_router_weight = router_weights[top_x,router_weights_idx]
            # 最终的current_token_router_weight 的维度是 (selected_token_number)
            current_token_router_weight = current_token_router_weight.unsqueeze(-1)
            # 现在最终的current_token_router_weight 的维度就变成了 (selected_token_number,1)

            current_hidden_states = current_state * current_token_router_weight
            # current_state shape (selected_token_number,hidden_dim)
            # current_token_router_weight shape (selected_token_number,1)
            ### 注意 : 这里有广播 ###
            # 最终的输出current_hidden_states shape is (selected_token_number,hidden_dim)

            final_hidden_states.index_add_(
                0,
                top_x,
                current_hidden_states.to(hidden_states.dtype)
            )
        # 把final_hidden_states还原到原来的shape
        final_hidden_states = final_hidden_states.reshape(batch_size,seq_len,hidden_dim)

        return final_hidden_states,router_logits
    # shape is (batch_size * seq_len,expert_number)

class SharedExpertMOE(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.routed_experts_moe = SparseMoE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(self.config.hidden_dim,self.config.hidden_dim) for _ in range(config.expert_number)
            ]
        )

    def forward(self,x):
        # x shape is (batch_size,seq_len,hidden_dim)
        batch_size,seq_len,hidden_dim = x.size()

        shared_experts_output_list = [
            expert(x) for expert in self.shared_experts
        ]
        shared_expert_output = torch.stack(
            shared_experts_output_list,
            dim = 0
        )
        # 输出的shape is (shared_experts_number,batch_size,seq_len,hidden_dim)

        shared_expert_out = shared_expert_output.sum(dim=0,keepdim=False)
        # shape is (batch_size,seq_len,hidden_dim)
        spares_moe_out,router_logits = self.routed_experts_moe(
            x
        )

        output = shared_expert_out + spares_moe_out

        return output , router_logits
    
def test_shared_expert_moe():
    x = torch.rand(2,4,16)
    config = MOEConfig(16,2,2)
    shared_expert_moe = SharedExpertMOE(config)
    out = shared_expert_moe(x)
    print(out[0].shape,out[1].shape)

test_shared_expert_moe()

3.ShareExpert SparseMoE训练验证代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class MOEConfig:
    def __init__(self,hidden_dim,expert_number,top_k,shared_experts_number = 2):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_expert_number = shared_experts_number

class BasicExpert(nn.Module):
    def __init__(self,feature_in,feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in,feature_out)

    def forward(self,x):
        return self.fc(x)

class MoERouter(nn.Module):
    def __init__(self,config):
        super().__init__()  
        self.gate = nn.Linear(config.hidden_dim,config.expert_number)
        # 但是后面只会选择top_k个专家

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self,x):
        # 假设expert_number = 8 , top_k = 2
        router_logits = self.gate(x)  # (batch_size * seq_len,expert_number)

        # 计算每一个专家的概率
        router_probs = F.softmax(router_logits,dim = 1, dtype = torch.float)

        # 计算top_k专家的输出
        # 注意 : top_k 是可以反向传播的
        router_weights , selected_experts_indices = torch.topk(
            router_probs,
            self.top_k,
            dim = -1
        )
        # 输出的两个元素router_weights , selected_experts_indices的维度都是 (batch_size * seq_len,top_k)

        # 对router_weights重新做归一化
        router_weights = router_weights / router_weights.sum(
            dim = -1,keepdim=True
        )
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number
        )
        ### 很重要 : 输出的维度是 (batch_size * seq_len,top_k,expert_number)

        expert_mask = expert_mask.permute(2,1,0)
        # 输出维度变成了 (expert_number,top_k,batch_size * seq_len)

        return router_logits,router_weights,selected_experts_indices,expert_mask
        # router_logits shape is (batch_size * seq_len,expert_number)
        # router_weights shape is (batch_size * seq_len,top_k)
        # selected_experts_indices shape is (batch_size * seq_len,top_k)
        # expert_mask shape is (expert_number,top_k,batch_size * seq_len)

class SparseMoE(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config

        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

    # 初始化专家
        self.experts = nn.ModuleList(
            BasicExpert(
                config.hidden_dim,
                config.hidden_dim
            ) for _ in range(config.expert_number)
        )
        self.router = MoERouter(config)

    def forward(self,x):
        # x shape is (batch_size,seq_len,hidden_dim)
        batch_size,seq_len,hidden_dim = x.size()

        # 因为是对token维度做计算,所以需要将x reshape 成 (batch_size * seq_len,hidden_dim)
        hidden_states = x.view(-1,hidden_dim)

        # 做相关的专家的计算
        router_logits,router_weights,selected_experts_indices,expert_masks = self.router(
            hidden_states
        )

        # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
        # 最终的维度肯定是 (batch_size * seq_len,hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * seq_len,hidden_dim),
             dtype = hidden_states.dtype,
             device = hidden_states.device
        )

        # 遍历每一个专家,把选中的这个专家的token的hidden_states加到final_hidden_states中
        # 例如:expert-0可能有100个token选中了:
        # token 的总数是 batch_size * seq_len
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # expert_masks shape is (expert_number,top_k,batch_size * seq_len)
            current_expert_mask = expert_masks[expert_idx]
            # 因为已经将expert_number取出来了,所以current_expert_mask就是一个二维的值了
            # current_expert_mask shape is (top_k,batch_size * seq_len)

            router_weights_idx , top_x = torch.where(current_expert_mask)
            """
            torch.where()的用法:
            where 返回矩阵中为true的行、列位置,这里分别代表top_k中第几个和tokens中的第几个
            """
            # router_weights_idx 是 0 or 1
            # 假设top_k是2,那么idx是代表在top_k里面的第几个(one-hot编码中的0 or 1)
            # router_weights_idx 表示这个token是作为当前专家的top1还是top2

            # top_x 是 token 在 batch_size * seq_len 中的位置索引
            # 例如对于 batch_size = 2,seq_len = 4 的输入:
            # top_x 的值范围是 0 - 7 , 表示在展平后的8个 token 中的位置
            # router_weights_idx 和 top_x 都是一维的值
            # router_weights_idx 是用来选择weight的
            # top_x 是用来选 hidden_states的

            # hidden_states shape is (batch_size * seq_len,hidden_dim)
            current_state = hidden_states.unsqueeze(0)[:,top_x,:].reshape(-1,hidden_dim)
            # hidden_states在做完unsqueeze(0)之后的维度为 (1,batch_size * seq_len,hidden_dim)
            # [:,top_x,:] 后的维度为 (1,top_x,hidden_dim)
            # 最终的current_state shape is (selected_token_number,hidden_dim)

            current_state = expert_layer(current_state)

            # router_weights shape is (batch_size * seq_len,top_k)
            # top_x 对应的是token的维度
            current_token_router_weight = router_weights[top_x,router_weights_idx]
            # 最终的current_token_router_weight 的维度是 (selected_token_number)
            current_token_router_weight = current_token_router_weight.unsqueeze(-1)
            # 现在最终的current_token_router_weight 的维度就变成了 (selected_token_number,1)

            current_hidden_states = current_state * current_token_router_weight
            # current_state shape (selected_token_number,hidden_dim)
            # current_token_router_weight shape (selected_token_number,1)
            ### 注意 : 这里有广播 ###
            # 最终的输出current_hidden_states shape is (selected_token_number,hidden_dim)

            final_hidden_states.index_add_(
                0,
                top_x,
                current_hidden_states.to(hidden_states.dtype)
            )
        # 把final_hidden_states还原到原来的shape
        final_hidden_states = final_hidden_states.reshape(batch_size,seq_len,hidden_dim)

        return final_hidden_states,router_logits
    # shape is (batch_size * seq_len,expert_number)

class SharedExpertMOE(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.routed_experts_moe = SparseMoE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(self.config.hidden_dim,self.config.hidden_dim) for _ in range(config.expert_number)
            ]
        )

    def forward(self,x):
        # x shape is (batch_size,seq_len,hidden_dim)
        batch_size,seq_len,hidden_dim = x.size()

        shared_experts_output_list = [
            expert(x) for expert in self.shared_experts
        ]
        shared_expert_output = torch.stack(
            shared_experts_output_list,
            dim = 0
        )
        # 输出的shape is (shared_experts_number,batch_size,seq_len,hidden_dim)

        shared_expert_out = shared_expert_output.sum(dim=0,keepdim=False)
        # shape is (batch_size,seq_len,hidden_dim)
        spares_moe_out,router_logits = self.routed_experts_moe(
            x
        )

        output = shared_expert_out + spares_moe_out

        return output , router_logits


def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    """
    计算 Switch Transformers 的负载均衡损失
    
    Args:
        router_logits: shape [batch_size * sequence_length, num_experts]
        num_experts: 专家数量
    
    Returns:
        total_loss: 总损失 = auxiliary_loss + z_loss
    """
    # 计算路由概率
    router_probs = torch.softmax(router_logits, dim=-1)  # [b*s, num_experts]
    
    # 获取每个token的最优专家
    _, selected_experts = torch.topk(router_probs, k=2, dim=-1) 
    
    # 创建one-hot矩阵表示选中的专家
    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float() 
    
    # 计算每个专家的期望负载 (理想情况下应该是 1/num_experts)
    expected_load = torch.ones_like(router_probs) / num_experts
    
    # 计算实际负载 (每个专家处理的token数量除以总token数量)
    # 在batch维度上计算平均值
    actual_load = mask.mean(dim=0)
    
    # 计算auxiliary loss
    # 这会惩罚负载分布与期望负载的差异
    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts
    
    # 计算z_loss (可选)
    # 这会惩罚过大的路由logits
    z_loss = torch.mean(torch.square(router_logits))
    z_loss_weight = 0.001  # 可调整的超参数
    
    # 总损失
    total_loss = aux_loss + z_loss * z_loss_weight
    
    return total_loss

def test_moe_training():
    # Create a simple dataset
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 100
    
    # Initialize model and optimizer
    config = MOEConfig(hidden_dim=hidden_dim, 
                      expert_number=4,
                      top_k=2,
                      shared_experts_number=2)
    model = SharedExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    model.train()
    for batch in range(num_batches):
        # Generate random input data
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)
        
        # Forward pass
        output, router_logits = model(x)

        # Compute losses
        # MSE loss for prediction
        mse_loss = F.mse_loss(output, target)
        
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
        # Combined loss
        total_loss = mse_loss + 0.01 * aux_loss
        
        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                  f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

4.注意事项提示:

1) 为什么需要返回router_logits?

因为我们希望每个专家处理的token数量差不多,尽量保持均衡,这也是switch transformer里面提出来的那个损失函数(Switch Transformers的负载均衡损失)。

2)Loss Function的重构

原本就只有一个cross entropy的Loss,但是现在引入了Switch Transformers的负载均衡损失之后,整体训练时的Loss Fuction是两个Loss的加权之和

5.课后作业:【大模型面试经典题】:Top_K是如何实现反向传播的?

要理解MoE中​​Top-K操作的反向传播​​,核心是解决​​“原生Top-K不可导”​​的问题,并理清​​梯度从最终损失到路由层、专家层的传递路径​​。

(1)Top-K操作的反向传播挑战

原生torch.topk是​​硬选择操作​​——仅保留每个token对专家分数最高的top_k个值,未被选中的位置梯度直接置零。这种“一刀切”会导致两个问题:

​梯度消失​​:未被选中的专家无法获得梯度,难以学习如何提升分数以进入Top-K。

​路由策略僵化​​:模型无法微调“选哪些专家”的决策,因为Top-K的边界是固定的。

(2)解决方案:软权重与掩码

通过​​保留router_probs(Softmax后的概率)的可导性​​,并将Top-K转化为​​软权重加权和​​,规避了原生Top-K的不可导问题。关键设计如下:

1. 不依赖“硬索引”,而是“软权重”

路由层MoERouter的输出不仅是Top-K的​​专家索引​​(selected_experts_indices),更重要的是:

router_probs​:每个token对所有专家的Softmax概率(可导);

router_weights​:Top-K专家的概率归一化结果(可导);

expert_mask​:专家- token的掩码(不可导,但不影响梯度传递)。

这些输出将“硬选择”转化为“软权重”,使得梯度可以通过router_probsrouter_weights传递。

2. 梯度传递的核心:router_weightsrouter_probs的关联

router_weightsrouter_probs的Top-K部分归一化后的结果:

router_weights = router_probs.gather(dim=-1, index=selected_experts_indices)  # 取Top-K的概率值
router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)     # 归一化

虽然torch.topk不可导,但router_weights的梯度​​仅传递给router_probs的Top-K位置​​(未被选中的位置梯度为零)。具体来说:

对于router_weights中的元素\omega _{ij}(token i对专家j的Top-K权重),其梯度\frac{\partial \omega _{ij}}{\partial p_{ik}}p_{ik}是token i对专家k的Softmax概率)满足:

\frac{\partial \omega _{ij}}{\partial p_{ik}} = \delta _{jk}\cdot \frac{\sum_{m = 1}^{K}p_{im} - p_{ij}}{\left ( \sum_{m = 1}^{K}p_{im} \right )^{2}}

其中\delta _{jk}是克罗内克函数(仅当j=k时为1,否则为0)。这意味着,\omega _{ij}的梯度仅影响p_{ij}(token i对专家j的Softmax概率),其他p_{ik}k≠j)的梯度为零。

3. 梯度从router_weights到路由层的传递

router_probsrouter_logitsgate层的输出)的Softmax结果,其梯度通过Softmax的导数公式传递给router_logits

\frac{\partial p_{ij}}{\partial l_{ij}} = p_{ij}\cdot (1 - p_{ij})

\frac{\partial p_{ij}}{\partial l_{ik}} = -p_{ij}\cdot p_{ik} (k \neq j)

(3)完整反向传播路径:从损失到所有参数

以最终损失loss为例,梯度传递路径如下:

1. 从lossfinal_hidden_states

final_hidden_states是专家处理后的加权 sum,其梯度直接来自loss

其中h_itoken_i的最终隐藏状态。

2. 从final_hidden_states到专家参数

每个专家的输出e_j(h)BasicExpertfc层结果)乘以权重\omega _{ij}后,累积到final\, hidden\, states_i。因此:

专家参数(如BasicExpert.fc.weight)的梯度来自所有其负责的token的e_j(h_i)的梯度之和。

3. 从final_hidden_statesrouter_weights

final\, hidden\, states_i的梯度会传递给\omega _{ij}(token i对专家j的Top-K权重):

其中\frac{\partial h_i}{\partial \omega _{ij}}仅在\omega _{ij}\neq 0(即token i被选入专家j的Top-K)时非零。

4. 从router_weightsrouter_probsrouter_logits

如前所述,\omega _{ij}的梯度仅传递给router_probs的Top-K位置,再通过Softmax导数传递给router_logits

5. 从router_logitsgate层参数

router_logits的梯度传递给gate层的权重和偏置,更新路由策略:

五、致谢

以上仅仅是本人对于MoE的一些拙见,如果小伙伴们有不同的理解,希望能够在评论区积极讨论哈!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值