# MOELayer DEMO及注释 — MoE layer demo with step-by-step annotations
import copy
import torch
from typing import Any
from typing import Callable, Dict, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
# Fixed RNG seed so the parameter initialisation and the random input drawn in
# main() are repeatable from run to run (same PyTorch build), which keeps the
# printed intermediate tensors comparable.
torch.manual_seed(42)
class MixtralParallelMLPBM(Module):
    """Gated feed-forward block used as a single expert in this demo.

    Computes ``w2(silu(w1(x)) * w3(x))``. Despite the name, nothing here is
    split across ranks: all three projections are ordinary
    ``torch.nn.Linear`` layers (with bias).

    Args:
        hidden_size: size of the last dimension of the input and the output.
        ffn_hidden_size: size of the intermediate projection.
    """

    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        # w1 and w3 both project hidden -> ffn; silu(w1 x) gates w3 x.
        # w2 projects the gated product back to hidden.
        self.w1 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.w2 = torch.nn.Linear(ffn_hidden_size, hidden_size)
        self.w3 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.act_fn = F.silu

    def forward(self, hidden_states):
        """Apply the gated MLP over the last dimension of ``hidden_states``."""
        print("hidden_states:", hidden_states.shape)
        gated = self.act_fn(self.w1(hidden_states))
        up = self.w3(hidden_states)
        return self.w2(gated * up)
class Experts(torch.nn.Module):
    """Holds the experts that live on one expert-parallel (EP) rank.

    ``expert`` is deep-copied ``num_local_experts`` times, so every local
    expert starts from identical parameters.

    Args:
        expert: module applied to each local expert's slice of the input.
            It may return either a tensor or an ``(output, bias)`` tuple.
        num_local_experts: number of copies of ``expert`` to hold.
    """

    def __init__(self, expert, num_local_experts=1):
        super().__init__()
        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)]
        )
        self.num_local_experts = num_local_experts

    def forward(self, inputs):
        """Run local expert ``j`` on ``inputs[:, j]`` and restack the results.

        Args:
            inputs: tensor whose dim 1 indexes the local experts (MOELayer in
                this file passes (ep_size, num_local_experts, capacity, hidden)).

        Returns:
            Tensor with the per-expert outputs stacked back along dim 1.

        Raises:
            ValueError: if ``inputs.shape[1] != num_local_experts``. Without
                this check ``chunk`` would produce multi-row slices that
                ``squeeze(dim=1)`` leaves untouched, silently yielding an
                output with an extra dimension.
        """
        print("Experts input:", inputs.shape)
        if inputs.dim() < 2 or inputs.shape[1] != self.num_local_experts:
            raise ValueError(
                f"expected dim 1 of size {self.num_local_experts} "
                f"(one slice per local expert), got shape {tuple(inputs.shape)}"
            )
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            print("chunk:", chunk.shape)
            chunk = torch.squeeze(chunk, dim=1).contiguous()
            print("chunk:", chunk.shape)
            out = expert(chunk)
            # Unpack an (output, bias) pair BEFORE touching .shape: a tuple has
            # no .shape, so printing first made this branch crash.
            if isinstance(out, tuple):
                out, bias = out
                if bias is not None:
                    out = out + bias
            print("expert out:", out.shape)
            expert_outputs.append(torch.unsqueeze(out, dim=1))
        return torch.cat(expert_outputs, dim=1)
def _one_hot_to_float(x, num_classes):
return F.one_hot(x, num_classes=num_classes).float()
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor):
# gates has shape of S,E
num_tokens = gates.shape[0]
num_experts = gates.shape[1]
max_capacity = num_tokens
# to(torch.int64) works around a bug in torch.onnx.export:
# it should cast k to int64 when converting torch.topk but it doesn't.
capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
if capacity < min_capacity:
capacity = min_capacity.to(torch.int64)
elif capacity > max_capacity:
capacity = torch.tensor(max_capacity, dtype=torch.int64)
return capacity
def top1gating(logits, capacity_factor=1.1, min_capacity=4):
    """Route every token to its single most probable expert, with capacity.

    Each expert accepts at most ``capacity`` tokens (computed by
    ``_capacity``). Tokens that do not fit are dropped, i.e. they get an
    all-zero row in both returned tensors.

    Args:
        logits: router logits of shape (S, E), with S tokens and E experts.
            Callers are expected to pass fp32 logits (TopKGate does).
        capacity_factor: forwarded to ``_capacity``. The default 1.1 is the
            value that used to be hard-coded here.
        min_capacity: forwarded to ``_capacity``. The default 4 is the value
            that used to be hard-coded here.

    Returns:
        ``(combine_weights, dispatch_mask)``, both of shape (S, E, C) where C
        is the capacity:
        - ``combine_weights[s, e, c]`` is token s's softmax probability for
          expert e if s was placed in slot c of expert e, and 0 otherwise.
        - ``dispatch_mask`` is ``combine_weights != 0``.

    Worked example from a recorded run (S=16, E=4, capacity=5):
        chosen expert per token: [0, 3, 0, 3, 1, 3, 3, 3, 3, 3, 3, 3, 0, 1, 3, 2]
        Expert 3 was chosen by 10 tokens but holds only 5, so tokens 8, 9,
        10, 11 and 14 were dropped.
        slot of each kept token: [0, 0, 1, 1, 0, 2, 3, 4, -, -, -, -, 2, 1, -, 0]
    """
    # [S, E]: probability of each token choosing each expert.
    token_sel_expert_weights = F.softmax(logits, dim=1)
    print(f"5.softmax:{token_sel_expert_weights.shape} token_sel_expert_weights:\n{token_sel_expert_weights}")

    capacity = _capacity(token_sel_expert_weights,
                         torch.as_tensor(capacity_factor),
                         torch.as_tensor(min_capacity))
    print("6.top1gating capacity:", capacity)

    # [S]: index of the most probable expert for each token.
    token_sel_expert_idx = torch.argmax(token_sel_expert_weights, dim=1)
    print("7.每个token对应的专家:", token_sel_expert_idx.shape, "data:", token_sel_expert_idx)

    num_experts = int(token_sel_expert_weights.shape[1])
    # [S, E]: one-hot routing mask before capacity is enforced.
    token_sel_expert_mask = F.one_hot(token_sel_expert_idx, num_classes=num_experts)
    print("8.one_hot 编码:", token_sel_expert_mask.shape, "token_sel_expert_mask:\n", token_sel_expert_mask)

    # Enforce capacity: for every expert (column) take `capacity` row indices
    # via topk over the 0/1 mask, then keep only those rows.
    # NOTE: topk never sees the gate probabilities here. All tokens routed to
    # the same expert tie at 1, so which of them survive an over-full expert
    # is decided by topk's tie-breaking, NOT by their weights. If an expert
    # has fewer than `capacity` tokens, topk also returns some 0-rows, which
    # the multiply below leaves at 0.
    expert_sel_top_c_token_idx = torch.topk(token_sel_expert_mask, k=capacity, dim=0)[1]  # [C, E]
    print(f"9:获取top{capacity}:{expert_sel_top_c_token_idx.shape} expert_sel_top_c_token_idx:\n{expert_sel_top_c_token_idx}")
    mask = torch.zeros_like(token_sel_expert_mask).scatter_(0, expert_sel_top_c_token_idx, 1)
    print(f"10.将上面index所在的位置填成1:{mask.shape},mask:\n{mask}")
    token_sel_expert_mask *= mask
    print(f"11.生成最后的mask:{token_sel_expert_mask.shape} token_sel_expert_mask:\n{token_sel_expert_mask}")

    # [S, E]: keep only the probability of the expert each surviving token was
    # routed to. Dropped tokens become all-zero rows. No renormalisation is done.
    token_sel_expert_mask_float = token_sel_expert_mask.float()
    token_sel_expert_weights = token_sel_expert_weights * token_sel_expert_mask_float
    print(f"12.用mask去取softmax后的值:{token_sel_expert_weights.shape},token_sel_expert_weights:\n{token_sel_expert_weights}")

    # Running count per expert minus one gives each token's slot inside its
    # expert. Entries where the token did not pick that expert hold -1 or a
    # stale count; the multiply below zeroes them.
    token_idx_in_expert_with_noise = torch.cumsum(token_sel_expert_mask, dim=0) - 1
    print(f"13.token_idx_in_expert_with_noise:{token_idx_in_expert_with_noise.shape} token_idx_in_expert_with_noise:\n{token_idx_in_expert_with_noise}")
    masked_token_idx_in_expert = token_idx_in_expert_with_noise * token_sel_expert_mask
    print(f"14.masked_token_idx_in_expert:{masked_token_idx_in_expert.shape} masked_token_idx_in_expert:\n{masked_token_idx_in_expert}")

    # [S]: each row has at most one non-zero mask entry, so summing over E
    # extracts the token's slot (0 for dropped tokens, whose weights are 0).
    token_offset_for_expert = torch.sum(masked_token_idx_in_expert, dim=1)
    print(f"15.token_offset_for_expert:{token_offset_for_expert.shape} token_offset_for_expert:\n{token_offset_for_expert}")

    # [S, C]: one-hot slot position of each token.
    token_locations_sc = _one_hot_to_float(token_offset_for_expert, capacity)
    print(f"16.token_locations_sc:{token_locations_sc.shape} token_locations_sc:\n{token_locations_sc}")

    # [S, E] x [S, C] -> [S, E, C]: per-token outer product, so the token's
    # weight lands at (its expert, its slot) and every other entry is 0.
    combine_weights = torch.einsum("se,sc->sec", token_sel_expert_weights, token_locations_sc)
    print(f"17.combine_weights:{combine_weights.shape} combine_weights:\n{combine_weights}")

    dispatch_mask = combine_weights.bool()
    return combine_weights, dispatch_mask
class TopKGate(Module):
    """Router: scores every token against every expert, then applies top-1 gating.

    Args:
        hidden_size: feature size of each input token.
        num_experts: number of experts to score against.
    """

    weight: torch.nn.Linear

    def __init__(self, hidden_size, num_experts) -> None:
        super().__init__()
        # Bias-free projection. Only its (num_experts, hidden_size) weight
        # matrix is used, via F.linear in forward.
        router = torch.nn.Linear(hidden_size, num_experts, bias=False)
        self.weight = router.float()

    def forward(self, gate_input):
        """Compute fp32 router logits for ``gate_input`` and return ``top1gating(logits)``."""
        input_fp32 = gate_input.float()
        router_weight = self.weight.weight.float()
        logits = F.linear(input_fp32, weight=router_weight, bias=None)
        print("4.TopKGate输入:", input_fp32.shape, "权值:", self.weight.weight.shape, "logits输出:", logits.shape)
        return top1gating(logits)
class MOELayer(Module):
    """Mixture-of-experts layer: route tokens, run the experts, combine outputs.

    This is a single-process walkthrough of an expert-parallel (EP) MoE layer.
    The two all-to-all exchanges a distributed run performs are left as
    commented-out ``_AllToAll`` calls, so the whole (ep_size, num_local_experts,
    ...) buffer is handed to the single local ``experts`` module.

    Steps in ``forward``:
      1. The gate builds an (S, E, C) routing tensor: S tokens, E experts and
         C slots ("capacity") per expert. Each entry says how much of a token
         goes to one slot of one expert, so summing over (E, C) with those
         weights reassembles the token.
      2. ``einsum`` scatters token features into a fixed-size (E, C, M) buffer
         (M = hidden size). Fixed-size buffers bound each expert's workload
         even when routing is unbalanced.
      3. (Distributed only) an all-to-all would move each expert's slots to the
         EP rank that owns that expert.
      4. The buffer is viewed as (ep_size, num_local_experts, C, M) and passed
         to ``experts``.
      5. (Distributed only) a second all-to-all would return the results to
         the ranks the tokens came from.
      6. A weighted ``einsum`` with the combine weights folds (E, C, M) back
         to (S, M).

    Args:
        gate: module mapping (S, M) tokens to ``(combine_weights,
            dispatch_mask)``, both (S, E, C).
        experts: module applied to the (ep_size, num_local_experts, C, M)
            buffer; must return a tensor that reshapes to (E, C, M).
        ep_size: expert-parallel world size.
        num_local_experts: experts hosted per EP rank.
        pipe_experts, sequence_parallel, pipe_experts_multi_data,
        pipe_experts_multi_stream: accepted for interface compatibility;
            not used by this demo.
    """

    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size,
                 num_local_experts,
                 pipe_experts: bool = False,
                 sequence_parallel: bool = True,
                 pipe_experts_multi_data: int = 1,
                 pipe_experts_multi_stream: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.num_experts = ep_size * num_local_experts
        self.exp_counts = None
        self.l_aux = None

    def set_ep_group(self, ep_group):
        """Store the process group the (currently disabled) all-to-alls would use."""
        self.ep_group = ep_group

    def forward(self, input, **kwargs):
        """Apply the MoE layer to ``input[0]``.

        Args:
            input: sequence whose first element is the hidden-state tensor
                (hidden size last, e.g. (seq_len, batch, hidden) in main()).
            **kwargs: ignored.

        Returns:
            Tensor with the same shape as ``input[0]``.

        Raises:
            ValueError: if the gate routes to a number of experts different
                from ``ep_size * num_local_experts``. Otherwise the reshape
                below would silently merge or split experts' slots.
        """
        hidden_states = input[0]
        d_model = hidden_states.shape[-1]
        num_experts = self.ep_size * self.num_local_experts
        # (..., M) -> (S, M): flatten all leading dims into one token axis.
        reshaped_input = hidden_states.reshape(-1, d_model)
        print("3.将维度转换为二维度(seq_len*batch_size,hidden_size):", reshaped_input.shape)

        # Gate: (S, E, C) combine weights plus the matching boolean dispatch mask.
        combine_weights, dispatch_mask = self.gate(reshaped_input)
        print(combine_weights.shape, dispatch_mask.shape)
        if combine_weights.shape[1] != num_experts:
            raise ValueError(
                f"gate routed to {combine_weights.shape[1]} experts, but "
                f"ep_size * num_local_experts = {num_experts}"
            )

        # (S, E, C) x (S, M) -> (E, C, M): sum the tokens dispatched to each
        # (expert, slot); with one token per slot this is a plain copy.
        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
        print(f"18.dispatched_input:{dispatched_input.shape},dispatched_input:\n{dispatched_input}")

        # Dispatch all-to-all (distributed only): would send each expert's
        # slots to the EP rank hosting that expert.
        # dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        # (E, C, M) -> (ep_size, num_local_experts, C, M)
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
        print("dispatched_input:", dispatched_input.shape)
        expert_output = self.experts(dispatched_input)
        print("expert_output:", expert_output.shape)

        # Combine all-to-all (distributed only): would return each expert's
        # results to the rank its tokens came from.
        # expert_output = _AllToAll.apply(self.ep_group, expert_output)

        # (ep_size, num_local_experts, C, M) -> (E, C, M)
        expert_output = expert_output.reshape(num_experts, -1, d_model)
        # (S, E, C) x (E, C, M) -> (S, M): weighted sum of each token's expert output(s).
        combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
        return combined_output.reshape(hidden_states.shape)
'''
模型配置
TP=2
PP=2
DP=2
EP=2
num_experts=4
'''
'''
all tp groups [[0, 1], [2, 3], [4, 5], [6, 7]]
all ep groups [[0, 2], [1, 3], [4, 6], [5, 7]]
all dp groups [[0, 2], [1, 3], [4, 6], [5, 7]]
all pp groups [[0, 4], [1, 5], [2, 6], [3, 7]]
'''
def main():
    """Run MOELayer once on random data, printing every intermediate step.

    Configuration: 4 experts arranged as ep_size=2 EP groups x 2 local
    experts; 16 tokens (seq_len=16, batch_size=1); hidden size 8.
    """
    num_local_experts = 2
    ep_size = 2
    # The gate must score exactly as many experts as MOELayer reshapes into
    # (ep_size * num_local_experts), so derive it instead of hard-coding 4.
    num_experts = ep_size * num_local_experts
    seq_len = 16
    batch_size = 1
    hidden_size = 8
    ffn_hidden_size = 16
    # With sequence parallelism over a TP group of this size, the full
    # sequence is tp_size * seq_len tokens and this rank sees seq_len of them.
    tp_size = 2

    # Construction order is kept (gate, then expert MLP, then input) so the
    # seeded RNG produces the same parameters and input as before.
    gate = TopKGate(hidden_size, num_experts)
    moe = MOELayer(gate,
                   Experts(MixtralParallelMLPBM(hidden_size, ffn_hidden_size),
                           num_local_experts),
                   ep_size,
                   num_local_experts)
    hidden_states = torch.randn(seq_len, batch_size, hidden_size, dtype=torch.float32)
    # Shapes in this message are derived from the config above; they used to
    # be hard-coded as (32,1,64)/(16,1,64), which did not match hidden_size=8.
    print(f"1.原始的输入shape({tp_size * seq_len},{batch_size},{hidden_size}),"
          f"因为序列并行,进入到MOE时维度为({seq_len},{batch_size},{hidden_size})")
    print("2.输入数列的shpae:", hidden_states.shape)
    output = moe([hidden_states])
    print("output:", output.shape)


# Runs the demo at module level (also on import), as the original script did.
main()