MOELayer DEMO及注释

MOELayer DEMO及注释

import copy
import torch
from typing import Any
from typing import Callable, Dict, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
torch.manual_seed(42)

class MixtralParallelMLPBM(Module):
    def __init__(self,hidden_size,ffn_hidden_size):
        super(MixtralParallelMLPBM,self).__init__()
        self.w1 = torch.nn.Linear(hidden_size,ffn_hidden_size)
        self.w2 = torch.nn.Linear(ffn_hidden_size,hidden_size)
        self.w3 = torch.nn.Linear(hidden_size,ffn_hidden_size)
        self.act_fn = F.silu
    def forward(self, hidden_states):
        print("hidden_states:",hidden_states.shape)
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
 
class Experts(torch.nn.Module):
    def __init__(self, expert, num_local_experts=1):
        super(Experts, self).__init__()
        self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
        self.num_local_experts = num_local_experts
        
    def forward(self, inputs):
        print("Experts input:",inputs.shape)
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            print("chunk:",chunk.shape)
            chunk = torch.squeeze(chunk, dim=1).contiguous()
            print("chunk:",chunk.shape)
            out = expert(chunk)
            print("expert out:",out.shape)
            if type(out) is tuple:
                out, bias = out
                if bias is not None:
                    out = out + bias
            out = torch.unsqueeze(out, dim=1)
            expert_outputs += [out]

        expert_output = torch.cat(expert_outputs, dim=1)
        return expert_output
    
def _one_hot_to_float(x, num_classes):
    return F.one_hot(x, num_classes=num_classes).float()

def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor):
    # gates has shape of S,E
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    max_capacity = num_tokens
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    elif capacity > max_capacity:
        capacity = torch.tensor(max_capacity, dtype=torch.int64)
    return capacity
    
def top1gating(logits):
    """Implements Top1Gating on logits."""

    # everything is in fp32 in this function
    # token_sel_expert_weights: [S, E], 每个token选择每个专家的概率
    
    token_sel_expert_weights = F.softmax(logits, dim=1) #16,4
    print(f"5.softmax:{token_sel_expert_weights.shape} token_sel_expert_weights:\n{token_sel_expert_weights}")
    '''
    tensor([[0.5426, 0.1172, 0.0655, 0.2747],
            [0.1293, 0.1390, 0.1795, 0.5521],
            [0.5180, 0.0419, 0.2816, 0.1584],
            [0.2191, 0.2966, 0.1691, 0.3152],
            [0.2212, 0.3157, 0.1812, 0.2819],
            [0.1572, 0.2165, 0.2931, 0.3332],
            [0.3198, 0.0820, 0.2499, 0.3483],
            [0.1738, 0.1981, 0.1453, 0.4828],
            [0.1618, 0.2546, 0.1643, 0.4193],
            [0.2306, 0.1819, 0.2694, 0.3181],
            [0.1739, 0.0921, 0.1228, 0.6112],
            [0.1355, 0.2796, 0.1024, 0.4826],
            [0.3720, 0.1553, 0.1946, 0.2781],
            [0.2496, 0.4208, 0.1395, 0.1901],
            [0.2637, 0.1050, 0.2761, 0.3551],
            [0.2899, 0.1759, 0.3855, 0.1488]]    
    '''
    capacity = _capacity(token_sel_expert_weights, torch.tensor(1.1),torch.tensor(4))
    print("6.top1gating capacity:",capacity)

    # [S] 每个token对应的专家(取概率最大的)
    token_sel_expert_idx = torch.argmax(token_sel_expert_weights, dim=1) #[16]
    print("7.每个token对应的专家:",token_sel_expert_idx.shape,"data:",token_sel_expert_idx)
    #[3, 1, 3, 0, 1, 0, 2, 0, 3, 0, 1, 1, 1, 1, 0, 2]
    num_experts = int(token_sel_expert_weights.shape[1])
    token_sel_expert_mask = F.one_hot(token_sel_expert_idx, num_classes=num_experts)
    print("8.one_hot 编码:",token_sel_expert_mask.shape,"token_sel_expert_mask:\n",token_sel_expert_mask) #16,4
    '''
     tensor([[1, 0, 0, 0],
             [0, 0, 0, 1],
             [1, 0, 0, 0],
             [0, 0, 0, 1],
             [0, 1, 0, 0],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 1],
             [0, 0, 1, 0]])
    '''

    # 通过topC每个专家选择至多C个token,然后和原始的mask1(每个专家可能选择超过C个token)矩阵相乘,
    # 丢掉超过专家容量的权重低的token,更新得到 token_sel_expert_mask
    expert_sel_top_c_token_idx = torch.topk(token_sel_expert_mask, k=capacity, dim=0)[1]
    #5,4
    print(f"9:获取top{capacity}:{expert_sel_top_c_token_idx.shape} expert_sel_top_c_token_idx:\n{expert_sel_top_c_token_idx}")
    '''
    tensor([[ 0,  4, 15,  1],
            [ 2, 13,  0,  3],
            [12,  0,  1,  5],
            [ 1,  1,  2,  6],
            [ 3,  2,  3,  7]]) 
    '''    
    mask=torch.zeros_like(token_sel_expert_mask).scatter_(0, expert_sel_top_c_token_idx,1)
    print(f"10.将上面index所在的位置填成1:{mask.shape},mask:\n{mask}")
    '''
    tensor([[1, 1, 1, 0],
            [1, 1, 1, 1],
            [1, 1, 1, 0],
            [1, 0, 1, 1],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 1, 0]])    
    '''    
    token_sel_expert_mask *= mask
    print(f"11.生成最后的mask:{token_sel_expert_mask.shape} token_sel_expert_mask:\n{token_sel_expert_mask}")
    '''
    tensor([[1, 0, 0, 0],
            [0, 0, 0, 1],
            [1, 0, 0, 0],
            [0, 0, 0, 1],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 1],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 1, 0]])    
    '''
    
    # Normalize gate probabilities
    token_sel_expert_mask_float = token_sel_expert_mask.float()
    token_sel_expert_weights = token_sel_expert_weights * token_sel_expert_mask_float
    print(f"12.用mask去取softmax后的值:{token_sel_expert_weights.shape},token_sel_expert_weights:\n{token_sel_expert_weights}")
    '''
    tensor([[0.5426, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.5521],
            [0.5180, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.3152],
            [0.0000, 0.3157, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.3332],
            [0.0000, 0.0000, 0.0000, 0.3483],
            [0.0000, 0.0000, 0.0000, 0.4828],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.3720, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.4208, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.3855, 0.0000]]    
    '''
    token_idx_in_expert_with_noise = torch.cumsum(token_sel_expert_mask, dim=0) - 1
    print(f"13.token_idx_in_expert_with_noise:{token_idx_in_expert_with_noise.shape} token_idx_in_expert_with_noise:\n{token_idx_in_expert_with_noise}")
    '''
    tensor([[ 0, -1, -1, -1],
            [ 0, -1, -1,  0],
            [ 1, -1, -1,  0],
            [ 1, -1, -1,  1],
            [ 1,  0, -1,  1],
            [ 1,  0, -1,  2],
            [ 1,  0, -1,  3],
            [ 1,  0, -1,  4],
            [ 1,  0, -1,  4],
            [ 1,  0, -1,  4],
            [ 1,  0, -1,  4],
            [ 1,  0, -1,  4],
            [ 2,  0, -1,  4],
            [ 2,  1, -1,  4],
            [ 2,  1, -1,  4],
            [ 2,  1,  0,  4]])    
    '''
    masked_token_idx_in_expert = token_idx_in_expert_with_noise * token_sel_expert_mask
    print(f"14.masked_token_idx_in_expert:{masked_token_idx_in_expert.shape} masked_token_idx_in_expert:\n{masked_token_idx_in_expert}")
    '''
    tensor([[0, 0, 0, 0],
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 0, 0, 1],
            [0, 0, 0, 0],
            [0, 0, 0, 2],
            [0, 0, 0, 3],
            [0, 0, 0, 4],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [2, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0]])    
    '''
    
    token_offset_for_expert = torch.sum(masked_token_idx_in_expert, dim=1)
    print(f"15.token_offset_for_expert:{token_offset_for_expert.shape} token_offset_for_expert:\n{token_offset_for_expert}")
    '''
    tensor([0, 0, 1, 1, 0, 2, 3, 4, 0, 0, 0, 0, 2, 1, 0, 0])
    '''

    token_locations_sc = _one_hot_to_float(token_offset_for_expert, capacity)
    print(f"16.token_locations_sc:{token_locations_sc.shape} token_locations_sc:\n{token_locations_sc}")
    '''
    tensor([[1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 1., 0., 0., 0.],
            [0., 1., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.],
            [1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0.],
            [0., 1., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.]])    
    '''
    combine_weights = torch.einsum("se,sc->sec", token_sel_expert_weights, token_locations_sc)#16,4 16,5 -> 16,4,5 #每一个token,在4个专家,哪一个容器里
    print(f"17.combine_weights:{combine_weights.shape} combine_weights:\n{combine_weights}")
    '''
    tensor([[[0.5426, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.5521, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.5180, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.3152, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.3157, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.3332, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.3483, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.4828]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.3720, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.4208, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

            [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.3855, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]    
    '''
    dispatch_mask = combine_weights.bool()
    return combine_weights, dispatch_mask

class TopKGate(Module):
    weight: torch.nn.Linear
    def __init__(self,hidden_size, num_experts) -> None:
        super(TopKGate,self).__init__()
        self.weight = torch.nn.Linear(hidden_size, num_experts, bias=False).float()
    def forward(self, gate_input):
        input_fp32 = gate_input.float()        
        logits = torch.nn.functional.linear(input_fp32, weight=self.weight.weight.float(), bias=None)
        print("4.TopKGate输入:",input_fp32.shape,"权值:",self.weight.weight.shape,"logits输出:",logits.shape)
        #16, 4
        gate_output = top1gating(logits)
        return gate_output
        
class MOELayer(Module):
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size,
                 num_local_experts,
                 pipe_experts: bool = False,
                 sequence_parallel: bool = True,
                 pipe_experts_multi_data: int = 1,
                 pipe_experts_multi_stream: bool = False) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.num_experts = ep_size * num_local_experts
        self.exp_counts = None
        self.l_aux = None

    def set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, input, **kwargs):
        '''
        一.目的:不同的expert负责不同的token
        二.主要步骤:
        1.生成特征分解矩阵,将输入token的特征拆解放在E(专家个数)C(每个专家的容器数)M(每个token的特征)的容器中
          矩阵每一个坐标内的值代表在某个维度上按多少比列分解特征,如果在某个维度上求和,就相当于对拆分后的特征进行加权求和
        2.通过矩阵乘,将输入token的特征拆解到以上矩阵(相当于用ECM的容器在装载、交换、变换特征,最后再将这个拆解后的特征加权[矩阵乘]还原到原始的维度)
        3.通过all2ll将分在不同ep rank的特征拉到各自己对应expert所在的rank上
        4.每个ep节点负责num_local_experts个expert。将上面的特征拆成num_local_experts块,分别送给不同的expert,之后合并结果
        5.将上面的结果通过all2all还原回之前每个RANK的排列顺序
        6.将分开的特征加权合并,生成(seq_len,hidden_size)的维度
        '''
        
        #input: 16,1,64
        d_model = input[0].shape[-1]
        reshaped_input = input[0].reshape(-1, d_model)
        #reshaped_input:16,64
        print("3.将维度转换为二维度(seq_len*batch_size,hidden_size):",reshaped_input.shape)
        # gate
        combine_weights, dispatch_mask = self.gate(reshaped_input)
        print(combine_weights.shape,dispatch_mask.shape)
        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) #16,4,5 16,64 -> 4 5 64  4个专家,5个容器,每个器放64个feature
        #将特征放在固定大小的容器里,防止了不均衡
        
        #dispatch_mask是token的分配矩阵,reshaped_input是每个token的特征,结果相当于将reshaped_input放在dispatch_mask里(加权存放)
        print(f"18.dispatched_input:{dispatched_input.shape},dispatched_input:\n{dispatched_input}")

        # dispatch all2all
        #ep是对expert进行拆分,每个expert承接一部分输入,all2ll之后是将前一半数据放在rank0,后一半数据放在rank1
        #dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
        #ep个专家组,每个组里2个专家,每个专用一个MixtralParallelMLPBM去提特征,每个MixtralParallelMLPBM的按tp并行计算,最后拼接在一起
        print("dispatched_input:",dispatched_input.shape)
        #每个expert计算一部分特征
        expert_output = self.experts(dispatched_input)
        print("expert_output:",expert_output.shape)
        # combine all2all
        #将特征还原回之前每个RANK的排列顺序,其实就相当于,通过all2all将特征当前对应的专家所在的rank上计算,计算完之后再放回去
        #expert_output = _AllToAll.apply(self.ep_group, expert_output)
        
        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        ##16,4,5  4,5,64  -> 16,64 将分开的特征加权合并,输出最终的16,64的特征
        combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
        return combined_output.reshape(input[0].shape)

'''
模型配置
TP=2
PP=2
DP=2
EP=2
num_experts=4
'''

'''
all tp gourps [[0, 1], [2, 3], [4, 5], [6, 7]]
all ep groups [[0, 2], [1, 3], [4, 6], [5, 7]]
all dp groups [[0, 2], [1, 3], [4, 6], [5, 7]]
all pp gourps [[0, 4], [1, 5], [2, 6], [3, 7]]
'''

def main():
    num_experts=4
    num_local_experts=2
    seq_len=16
    batch_size=1
    hidden_size=8
    ffn_hidden_size=16
    ep_size=2

    gate = TopKGate(hidden_size,num_experts)
    moe = MOELayer(gate, 
                    Experts(MixtralParallelMLPBM(hidden_size,ffn_hidden_size),
                            num_local_experts),
                    ep_size,
                    num_local_experts)

    input=torch.randn(seq_len, batch_size, hidden_size,dtype=torch.float32)
    print("1.原始的输入shape(32,1,64),因为序列并行,进入到MOE时维度为(16,1,64)")
    print("2.输入数列的shpae:",input.shape)
    output = moe([input])    
    print("output:",output.shape)

main()
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hi20240217

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值