MoE专家模块Demo

前言

随着MoE越来越火热,MoE本质就是将Transformer中的FFN层替换成了MoE-layer,其中每个MoE-Layer由一个gate和若干个experts组成。这里gate和每个expert都可以理解成是nn.linear形式的神经网络。既然如此,本篇文章将结合transformers结构构建一个MoE的demo供大家学习。该源码可直接使用。

一、MoE原理与设计原则

expert:术业有专攻。假设我的输入数据是“我爱吃炸鸡”,在原始的Transformer中,我们把这5个token送去一个FFN层做处理。但是现在我们发现这句话从结构上可以拆成“主语-我”,“谓语-爱吃”,“宾语-炸鸡”,秉持着术业有专攻的原则,我把原来的1个FFN拆分成若干个expert,分别用来单独解析“主语”,“谓语”,“宾语”,这样可能会达到更好的效果。

gate:那么我怎么知道要把哪个token送去哪个expert呢?很简单,我再训练一个gate神经网络,让它帮我判断就好了。

当然,这里并不是说expert就是用来解析主谓宾,只是举一个例子说明:不同token代表的含义不一样,因此我们可以用不同expert来对它们做解析。除了训练上也许能达到更好的效果外,MoE还能帮助我们在扩大模型规模的同时保证计算量是非线性增加的(因为每个token只用过topK个expert,不用过全量expert),这也是我们说MoE-layer是稀疏层的原因。

最后需要注意的是,在之前的表述中,我们说expert是从FFN层转变而来的,这很容易让人错理解成expert就是对FFN的平均切分,实际上你可以任意指定每个expert的大小,每个expert甚至可以>=原来单个FFN层,这并不会改变MoE的核心思想:token只发去部分expert时的计算量会小于它发去所有expert的计算量。

引用:https://mp.weixin.qq.com/s/76a-7fDJumv6iB08L2BUKg

二、构建完整transformers与MoE集成模块

这个模块定义了一个名为Block的PyTorch模块,代表了一个混合专家Transformer块,包括多头自注意力和计算MoE部分(SparseMoE)。其源码如下:

class Block(nn.Module):
   """Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """
   def __init__(self, n_embed, n_head, num_experts, top_k):
       super().__init__()
       self.sa = nn.MultiheadAttention(n_embed, n_head)  # 我直接调用官网的attention方法
       self.smoe = SparseMoE(n_embed, num_experts, top_k)
       self.ln1 = nn.LayerNorm(n_embed)
       self.ln2 = nn.LayerNorm(n_embed)

   def forward(self, x):
       qkv = self.ln1(x)
       x = x + self.sa(qkv,qkv,qkv)[0]  # 并不包含FFN结构
       x = x + self.smoe(self.ln2(x))
       return x

这段代码的解读:
__init__方法:在初始化方法中,定义了模块的结构。模块包含了一个nn.MultiheadAttention实例sa用于多头自注意力计算,我直接使用pytorch官网的方法模块,一个SparseMoE实例smoe用于稀疏专家计算,以及两个nn.LayerNorm实例ln1和ln2用于层归一化。

forward方法:在前向传播方法中,首先对输入张量x进行层归一化处理,然后通过多头自注意力模块对处理后的张量进行注意力计算,并将注意力计算结果与原始输入张量相加。接着,将相加后的张量通过稀疏专家模块进行计算,再次与原始输入张量相加。最后返回处理后的张量。

总体来说,这段代码实现了一个混合专家Transformer块,结合了多头自注意力和稀疏专家计算。在前向传播过程中,通过多头自注意力和稀疏专家计算两部分对输入张量进行处理,并保持了张量的维度一致。

注: x = x + self.smoe(self.ln2(x))使用了类似残差方法

三、专家模块定义

专家模块定义了一个名为Expert的PyTorch模块,代表了一个专家模块,用于对输入进行线性变换和非线性变换。

class Expert(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        dropout=0.1
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

四、路由门控模块

门控网络,也称为路由,确定哪个专家网络接收来自多头注意力的 token 的输出。举个例子解释路由的机制,假设有 4 个专家,token 需要被路由到前 2 个专家中。首先需要通过线性层将 token 输入到门控网络中。该层将对应于(Batch size,Tokens,n_embed)的输入张量从(2,4,32)维度,投影到对应于(Batch size、Tokens,num_expert)的新形状:(2、4,4)。其中 n_embed 是输入的通道维度,num_experts 是专家网络的计数。

class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        # add noise
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        # Noise logits
        noise_logits = self.noise_linear(mh_output)

        # Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)  # 在4个专家中选择最好的专家
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

五、稀疏MoE集成模块

通过路由获得indices在对每个专家循环获得对应mask,挑选有用的信息不断叠加而获得最终结果,我已改成伪代码如下说明:

# 1. 输入进入router得到两个输出
gating_output, indices = router(x)
# 2.初始化全零矩阵,后续叠加为最终结果
final_output = zeros_like(x)
# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
flat_x = flatten(x)
flat_gating_output = flatten(gating_output)
# 4. 对每个专家进行操作
for each expert in experts:
    # 5. 查看当前专家对哪些tokens在前top_k
    expert_mask = check_top_k_indices(indices, expert)
    # 6. 获取当前专家作用的token输入
    expert_input = select_expert_input(flat_x, expert_mask)
    # 7. 将token输入经过专家处理得到输出
    expert_output = expert(expert_input)
    # 8. 计算当前专家对作用的token的权重分数
    gating_scores = calculate_gating_scores(flat_gating_output, expert_mask)
    # 9. 将expert输出乘上权重分数
    weighted_output = expert_output * gating_scores
    # 10. 将结果叠加到最终输出中
    final_output += weighted_output
return final_output

总结解释:
这段代码实现了一个稀疏专家(SparseMoE)模块,结合了一个路由器(router)和多个专家(experts)。
输入首先通过路由器得到两个输出:一个是门控输出(gating_output),一个是索引(indices)。
然后初始化一个全零矩阵作为最终输出。
对每个专家进行操作,根据索引找出每个专家对哪些token起作用,然后将这些token输入到对应的专家中进行处理,根据门控输出的权重分配,将专家处理后的输出加权叠加到最终输出中。
最终返回加权叠加后的最终输出。

其源码如下:

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        # 1. 输入进入router得到两个输出
        gating_output, indices = self.router(x)  # x.shape=[2,6,64] gating_output.shape=[2,6,4]
        # 2.初始化全零矩阵,后续叠加为最终结果
        final_output = torch.zeros_like(x)  # final_output.shape=[2,6,64]

        # 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
        flat_x = x.view(-1, x.size(-1))  # flat_x.shape=[12,64]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))  # flat_gating_output.shape=[12,4]

        # 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权
        for i, expert in enumerate(self.experts):
            # 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2
            expert_mask = (indices == i).any(dim=-1)  # expert_mask.shape=[2,6]
            # 5. 展平操作
            flat_mask = expert_mask.view(-1)  # flat_mask=12
            # 如果当前专家是任意一个token的前top2
            if flat_mask.any():
                # 6. 得到该专家对哪几个token起作用后,选取token的维度表示
                expert_input = flat_x[flat_mask]  # 假设第0个专家选择了2个batch的7个维度为[7,64]
                # 7. 将token输入expert得到输出
                expert_output = expert(expert_input)  # 这个专家选择通道走了mlp结构,[7,64]

                # 8. 计算当前专家对于有作用的token的权重分数
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)  # 使用flat_mask选择第i个专家的权重
                # 9. 将expert输出乘上权重分数
                weighted_output = expert_output * gating_scores

                # 10. 循环进行做种的结果叠加,也就是越重要被专家选择越多就叠加越多
                final_output[expert_mask] += weighted_output.squeeze(1)  # 变成[7,64]

        return final_output

六、完整MoE的Demo

最后,给出即插即用的MoE完整Demo代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        dropout=0.1
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        # add noise
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        # Noise logits
        noise_logits = self.noise_linear(mh_output)

        # Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)  # 在4个专家中选择最好的专家
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        # 1. 输入进入router得到两个输出
        gating_output, indices = self.router(x)  # x.shape=[2,6,64] gating_output.shape=[2,6,4]
        # 2.初始化全零矩阵,后续叠加为最终结果
        final_output = torch.zeros_like(x)  # final_output.shape=[2,6,64]

        # 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
        flat_x = x.view(-1, x.size(-1))  # flat_x.shape=[12,64]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))  # flat_gating_output.shape=[12,4]

        # 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权
        for i, expert in enumerate(self.experts):
            # 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2
            expert_mask = (indices == i).any(dim=-1)  # expert_mask.shape=[2,6]
            # 5. 展平操作
            flat_mask = expert_mask.view(-1)  # flat_mask=12
            # 如果当前专家是任意一个token的前top2
            if flat_mask.any():
                # 6. 得到该专家对哪几个token起作用后,选取token的维度表示
                expert_input = flat_x[flat_mask]  # 假设第0个专家选择了2个batch的7个维度为[7,64]
                # 7. 将token输入expert得到输出
                expert_output = expert(expert_input)  # 这个专家选择通道走了mlp结构,[7,64]

                # 8. 计算当前专家对于有作用的token的权重分数
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)  # 使用flat_mask选择第i个专家的权重
                # 9. 将expert输出乘上权重分数
                weighted_output = expert_output * gating_scores

                # 10. 循环进行做种的结果叠加,也就是越重要被专家选择越多就叠加越多
                final_output[expert_mask] += weighted_output.squeeze(1)  # 变成[7,64]

        return final_output

class Block(nn.Module):
    """Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """
    def __init__(self, n_embed, n_head, num_experts, top_k):
        super().__init__()
        self.sa = nn.MultiheadAttention(n_embed, n_head)  # 我直接调用官网的attention方法
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        qkv = self.ln1(x)
        x = x + self.sa(qkv,qkv,qkv)[0]  # 并不包含FFN结构
        x = x + self.smoe(self.ln2(x))
        return x


if __name__ == '__main__':
    # 假设分词64个embed与4个头
    n_embed, n_head = 64, 4
    dropout = 0.1
    # 假设4个专家与2个top_k
    num_experts, top_k = 4, 2

    # 假设数据是2个batch、6个分词,后面代码注释都已这些假设具体化
    inputs = torch.rand(2, 6, n_embed)
    M = Block(n_embed, n_head, num_experts, top_k)

    y = M(inputs)

    print(y.shape)



  • 9
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。
1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。
1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

tangjunjun-owen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值