Mixtral MOE代码理解

我在看MOE的时候,虽然大概能够理解MOE的模型结构,但是看一些作者实现的代码(应该不是官方代码),虽然写的很好,但是始终理解无法彻底理解他代码的意思,于是,简单运行了一下,特此记录一下。
参考博客:
理解Mixtral Moe模型原理与代码实现
Mixtral Moe代码解读
这里就不过多解释MOE的原理,直接粘贴代码:

self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

# 获取到每个token的mlp层输入特征 
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

# 得到每个专家的打分,维度是batch * sequence, num_experts,取topk个专家
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

# 取到topk个专家的打分,需要计算在归一化一下,用于对后面的expert计算出来的结果进行加权
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# routing_weights、selected_experts 维度是一致的,取了topk   (bs * sl, topk)
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

# 如果不做后面的维度切换,那expert_mask的维度是 (bs*sl, topk, n_experts),但是后面要遍历n_experts来计算,所以颠倒一下,得到(n_experts, topk, bs * sl); 
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    idx, top_x = torch.where(expert_mask[expert_idx])
    
    """
    这样取到expert_mask[expert_idx],从上面的注释可以知道维度是
    [topk, bs * sl];torch.where的结果,第一个结果代表选到了哪一行,第二个代表选择了哪一列
    
    对应到实际意义,top_x表示取的列,也就是取哪些token
    而行表示,取到的这些token,根据路由gate计算,当前expert是排行第几;
    所以这里变量名字可能有点混淆,
    """
    
    # 没有token需要当前的expert计算
    if top_x.shape[0] == 0:
        continue
    
    # tensor index使用list比tensor快
    top_x_list = top_x.tolist()
    idx_list = idx.tolist()

    # 前面hidden states已经转成了 [bs * sl, hs],根据top_x 可以找到需要计算的token,这些token依旧是有序的
    current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
    
    # 找到这个expert对应的权重 乘进去
    # 上面计算的权重是routing_weights,维度是bs * sl, topk
    # 根据top_x_list 对应的token,idx_list表示topk中第几个
    # 可以直接取到相应的权重
    current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

    # 合到最终的特征里边去
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
    
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

这部分代码比较简单,属于常规的方式:

设置参数
batch_size = 2
sequence_length = 3
hidden_dim = 4
num_experts = 3
top_k = 2
hidden_states = torch.randn(2, 3, 4)
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

# 获取到每个token的mlp层输入特征 
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

# 得到每个专家的打分,维度是batch * sequence, num_experts,取topk个专家
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

hidden_states结果为:
在这里插入图片描述
经过gate,也就是预测每个专家的概率值:
在这里插入图片描述

routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

经过topK从三个专家中选择两个,结果如下:
在这里插入图片描述

routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

因为专家少了一个,或者推广到更过的专家,少了几个专家,仅剩的两个专家概率值之和不为1,这个目的就是,概率值归一
在这里插入图片描述

final_hidden_states = torch.zeros(

        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

则是用于存储最终的结果

expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

我们先看一下: torch.nn.functional.one_hot(selected_experts, num_classes=num_experts)的结果
在这里插入图片描述
其实就是对下面的内容转成one_hot编码:
在这里插入图片描述
permute(2, 1, 0)之后的结果:
在这里插入图片描述
可以理解为专家1负责两部分,一部分是top1的内容,一个是top2的内容

idx, top_x = torch.where(expert_mask[expert_idx])

expert_idx是第几个专家, where实现数值不为空的行和列的坐标,如下:
在这里插入图片描述
在这里插入图片描述
上面两个图是对应的,不为0的坐标分别为(0, 1),(0,3),(0,4)…就是把横纵坐标分开而已。

current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)

当前专家负责的token或者是对应的hidd_states,通过top_x进行获取,可能是出于严谨的问题,我尝试了一下 current_state = hidden_states[top_x_list]这么比较好理解。
在这里插入图片描述

current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

expert_layer(current_state)对current_state进行FFN操作(这里简单实现),routing_weights[top_x_list, idx_list, None]去除对应的权值。

final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

对最终的结果进行累加,加到对应的token上。类似于累加的形式,初始化为0,先加上一个专家的结果,等另外一个专家处理到这个token在加上这个专家的结果。
index_add_的用法如下:
在这里插入图片描述

  • 17
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值