前言
随着MoE越来越火热,MoE本质就是将Transformer中的FFN层替换成了MoE-layer,其中每个MoE-Layer由一个gate和若干个experts组成。这里gate和每个expert都可以理解成是nn.linear形式的神经网络。既然如此,本篇文章将结合transformers结构构建一个MoE的demo供大家学习。该源码可直接使用。
一、MoE原理与设计原则
expert:术业有专攻。假设我的输入数据是“我爱吃炸鸡”,在原始的Transformer中,我们把这5个token送去一个FFN层做处理。但是现在我们发现这句话从结构上可以拆成“主语-我”,“谓语-爱吃”,“宾语-炸鸡”,秉持着术业有专攻的原则,我把原来的1个FFN拆分成若干个expert,分别用来单独解析“主语”,“谓语”,“宾语”,这样可能会达到更好的效果。
gate:那么我怎么知道要把哪个token送去哪个expert呢?很简单,我再训练一个gate神经网络,让它帮我判断就好了。
当然,这里并不是说expert就是用来解析主谓宾,只是举一个例子说明:不同token代表的含义不一样,因此我们可以用不同expert来对它们做解析。除了训练上也许能达到更好的效果外,MoE还能帮助我们在扩大模型规模的同时保证计算量是非线性增加的(因为每个token只用过topK个expert,不用过全量expert),这也是我们说MoE-layer是稀疏层的原因。
最后需要注意的是,在之前的表述中,我们说expert是从FFN层转变而来的,这很容易让人错理解成expert就是对FFN的平均切分,实际上你可以任意指定每个expert的大小,每个expert甚至可以>=原来单个FFN层,这并不会改变MoE的核心思想:token只发去部分expert时的计算量会小于它发去所有expert的计算量。
引用:https://mp.weixin.qq.com/s/76a-7fDJumv6iB08L2BUKg
二、构建完整transformers与MoE集成模块
这个模块定义了一个名为Block的PyTorch模块,代表了一个混合专家Transformer块,包括多头自注意力和计算MoE部分(SparseMoE)。其源码如下:
class Block(nn.Module):
"""Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """
def __init__(self, n_embed, n_head, num_experts, top_k):
super().__init__()
self.sa = nn.MultiheadAttention(n_embed, n_head) # 我直接调用官网的attention方法
self.smoe = SparseMoE(n_embed, num_experts, top_k)
self.ln1 = nn.LayerNorm(n_embed)
self.ln2 = nn.LayerNorm(n_embed)
def forward(self, x):
qkv = self.ln1(x)
x = x + self.sa(qkv,qkv,qkv)[0] # 并不包含FFN结构
x = x + self.smoe(self.ln2(x))
return x
这段代码的解读:
__init__方法:在初始化方法中,定义了模块的结构。模块包含了一个nn.MultiheadAttention实例sa用于多头自注意力计算,我直接使用pytorch官网的方法模块,一个SparseMoE实例smoe用于稀疏专家计算,以及两个nn.LayerNorm实例ln1和ln2用于层归一化。
forward方法:在前向传播方法中,首先对输入张量x进行层归一化处理,然后通过多头自注意力模块对处理后的张量进行注意力计算,并将注意力计算结果与原始输入张量相加。接着,将相加后的张量通过稀疏专家模块进行计算,再次与原始输入张量相加。最后返回处理后的张量。
总体来说,这段代码实现了一个混合专家Transformer块,结合了多头自注意力和稀疏专家计算。在前向传播过程中,通过多头自注意力和稀疏专家计算两部分对输入张量进行处理,并保持了张量的维度一致。
注: x = x + self.smoe(self.ln2(x))使用了类似残差方法。
三、专家模块定义
专家模块定义了一个名为Expert的PyTorch模块,代表了一个专家模块,用于对输入进行线性变换和非线性变换。
class Expert(nn.Module):
def __init__(self, n_embd):
super().__init__()
dropout=0.1
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
四、路由门控模块
门控网络,也称为路由,确定哪个专家网络接收来自多头注意力的 token 的输出。举个例子解释路由的机制,假设有 4 个专家,token 需要被路由到前 2 个专家中。首先需要通过线性层将 token 输入到门控网络中。该层将对应于(Batch size,Tokens,n_embed)的输入张量从(2,4,32)维度,投影到对应于(Batch size、Tokens,num_expert)的新形状:(2、4,4)。其中 n_embed 是输入的通道维度,num_experts 是专家网络的计数。
class NoisyTopkRouter(nn.Module):
def __init__(self, n_embed, num_experts, top_k):
super(NoisyTopkRouter, self).__init__()
self.top_k = top_k
self.topkroute_linear = nn.Linear(n_embed, num_experts)
# add noise
self.noise_linear = nn.Linear(n_embed, num_experts)
def forward(self, mh_output):
# mh_ouput is the output tensor from multihead self attention block
logits = self.topkroute_linear(mh_output)
# Noise logits
noise_logits = self.noise_linear(mh_output)
# Adding scaled unit gaussian noise to the logits
noise = torch.randn_like(logits) * F.softplus(noise_logits)
noisy_logits = logits + noise
top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1) # 在4个专家中选择最好的专家
zeros = torch.full_like(noisy_logits, float('-inf'))
sparse_logits = zeros.scatter(-1, indices, top_k_logits)
router_output = F.softmax(sparse_logits, dim=-1)
return router_output, indices
五、稀疏MoE集成模块
通过路由获得indices在对每个专家循环获得对应mask,挑选有用的信息不断叠加而获得最终结果,我已改成伪代码如下说明:
# 1. 输入进入router得到两个输出
gating_output, indices = router(x)
# 2.初始化全零矩阵,后续叠加为最终结果
final_output = zeros_like(x)
# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
flat_x = flatten(x)
flat_gating_output = flatten(gating_output)
# 4. 对每个专家进行操作
for each expert in experts:
# 5. 查看当前专家对哪些tokens在前top_k
expert_mask = check_top_k_indices(indices, expert)
# 6. 获取当前专家作用的token输入
expert_input = select_expert_input(flat_x, expert_mask)
# 7. 将token输入经过专家处理得到输出
expert_output = expert(expert_input)
# 8. 计算当前专家对作用的token的权重分数
gating_scores = calculate_gating_scores(flat_gating_output, expert_mask)
# 9. 将expert输出乘上权重分数
weighted_output = expert_output * gating_scores
# 10. 将结果叠加到最终输出中
final_output += weighted_output
return final_output
总结解释:
这段代码实现了一个稀疏专家(SparseMoE)模块,结合了一个路由器(router)和多个专家(experts)。
输入首先通过路由器得到两个输出:一个是门控输出(gating_output),一个是索引(indices)。
然后初始化一个全零矩阵作为最终输出。
对每个专家进行操作,根据索引找出每个专家对哪些token起作用,然后将这些token输入到对应的专家中进行处理,根据门控输出的权重分配,将专家处理后的输出加权叠加到最终输出中。
最终返回加权叠加后的最终输出。
其源码如下:
class SparseMoE(nn.Module):
def __init__(self, n_embed, num_experts, top_k):
super(SparseMoE, self).__init__()
self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
self.top_k = top_k
def forward(self, x):
# 1. 输入进入router得到两个输出
gating_output, indices = self.router(x) # x.shape=[2,6,64] gating_output.shape=[2,6,4]
# 2.初始化全零矩阵,后续叠加为最终结果
final_output = torch.zeros_like(x) # final_output.shape=[2,6,64]
# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
flat_x = x.view(-1, x.size(-1)) # flat_x.shape=[12,64]
flat_gating_output = gating_output.view(-1, gating_output.size(-1)) # flat_gating_output.shape=[12,4]
# 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权
for i, expert in enumerate(self.experts):
# 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2
expert_mask = (indices == i).any(dim=-1) # expert_mask.shape=[2,6]
# 5. 展平操作
flat_mask = expert_mask.view(-1) # flat_mask=12
# 如果当前专家是任意一个token的前top2
if flat_mask.any():
# 6. 得到该专家对哪几个token起作用后,选取token的维度表示
expert_input = flat_x[flat_mask] # 假设第0个专家选择了2个batch的7个维度为[7,64]
# 7. 将token输入expert得到输出
expert_output = expert(expert_input) # 这个专家选择通道走了mlp结构,[7,64]
# 8. 计算当前专家对于有作用的token的权重分数
gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1) # 使用flat_mask选择第i个专家的权重
# 9. 将expert输出乘上权重分数
weighted_output = expert_output * gating_scores
# 10. 循环进行做种的结果叠加,也就是越重要被专家选择越多就叠加越多
final_output[expert_mask] += weighted_output.squeeze(1) # 变成[7,64]
return final_output
六、完整MoE的Demo
最后,给出即插即用的MoE完整Demo代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Expert(nn.Module):
def __init__(self, n_embd):
super().__init__()
dropout=0.1
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class NoisyTopkRouter(nn.Module):
def __init__(self, n_embed, num_experts, top_k):
super(NoisyTopkRouter, self).__init__()
self.top_k = top_k
self.topkroute_linear = nn.Linear(n_embed, num_experts)
# add noise
self.noise_linear = nn.Linear(n_embed, num_experts)
def forward(self, mh_output):
# mh_ouput is the output tensor from multihead self attention block
logits = self.topkroute_linear(mh_output)
# Noise logits
noise_logits = self.noise_linear(mh_output)
# Adding scaled unit gaussian noise to the logits
noise = torch.randn_like(logits) * F.softplus(noise_logits)
noisy_logits = logits + noise
top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1) # 在4个专家中选择最好的专家
zeros = torch.full_like(noisy_logits, float('-inf'))
sparse_logits = zeros.scatter(-1, indices, top_k_logits)
router_output = F.softmax(sparse_logits, dim=-1)
return router_output, indices
class SparseMoE(nn.Module):
def __init__(self, n_embed, num_experts, top_k):
super(SparseMoE, self).__init__()
self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
self.top_k = top_k
def forward(self, x):
# 1. 输入进入router得到两个输出
gating_output, indices = self.router(x) # x.shape=[2,6,64] gating_output.shape=[2,6,4]
# 2.初始化全零矩阵,后续叠加为最终结果
final_output = torch.zeros_like(x) # final_output.shape=[2,6,64]
# 3.展平,即把每个batch拼接到一起,这里对输入x和router后的结果都进行了展平
flat_x = x.view(-1, x.size(-1)) # flat_x.shape=[12,64]
flat_gating_output = gating_output.view(-1, gating_output.size(-1)) # flat_gating_output.shape=[12,4]
# 以每个专家为单位进行操作,即把当前专家处理的所有token都进行加权
for i, expert in enumerate(self.experts):
# 4. 对当前的专家(例如专家0)来说,查看其对所有tokens中哪些在前top2
expert_mask = (indices == i).any(dim=-1) # expert_mask.shape=[2,6]
# 5. 展平操作
flat_mask = expert_mask.view(-1) # flat_mask=12
# 如果当前专家是任意一个token的前top2
if flat_mask.any():
# 6. 得到该专家对哪几个token起作用后,选取token的维度表示
expert_input = flat_x[flat_mask] # 假设第0个专家选择了2个batch的7个维度为[7,64]
# 7. 将token输入expert得到输出
expert_output = expert(expert_input) # 这个专家选择通道走了mlp结构,[7,64]
# 8. 计算当前专家对于有作用的token的权重分数
gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1) # 使用flat_mask选择第i个专家的权重
# 9. 将expert输出乘上权重分数
weighted_output = expert_output * gating_scores
# 10. 循环进行做种的结果叠加,也就是越重要被专家选择越多就叠加越多
final_output[expert_mask] += weighted_output.squeeze(1) # 变成[7,64]
return final_output
class Block(nn.Module):
"""Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) """
def __init__(self, n_embed, n_head, num_experts, top_k):
super().__init__()
self.sa = nn.MultiheadAttention(n_embed, n_head) # 我直接调用官网的attention方法
self.smoe = SparseMoE(n_embed, num_experts, top_k)
self.ln1 = nn.LayerNorm(n_embed)
self.ln2 = nn.LayerNorm(n_embed)
def forward(self, x):
qkv = self.ln1(x)
x = x + self.sa(qkv,qkv,qkv)[0] # 并不包含FFN结构
x = x + self.smoe(self.ln2(x))
return x
if __name__ == '__main__':
# 假设分词64个embed与4个头
n_embed, n_head = 64, 4
dropout = 0.1
# 假设4个专家与2个top_k
num_experts, top_k = 4, 2
# 假设数据是2个batch、6个分词,后面代码注释都已这些假设具体化
inputs = torch.rand(2, 6, n_embed)
M = Block(n_embed, n_head, num_experts, top_k)
y = M(inputs)
print(y.shape)
335

被折叠的 条评论
为什么被折叠?



