本文旨在分析Dynamic router moe的实现
原始论文实现
def top_p_sampling_batched_all_sequence(logits, top_p=0.9, temperature=1.0):
"""
Apply Top-p sampling to every element in the sequence for each item in the batch.
Returns the selected token indices and the corresponding threshold indices.
:param logits: Logits from a language model with shape (sequence length, batch size, L)
:param top_p: Cumulative probability threshold (float)
:param temperature: Sampling temperature (float)
:return: Tuple of tensors (selected token indices, threshold indices) for each position in each sequence in the batch
"""
# Apply temperature
logits = logits / temperature
# Convert logits to probabilities
# probabilities = torch.softmax(logits, dim=-1)
# Sort probabilities and their indices in descending order
sorted_probs, sorted_indices = torch.sort(logits, descending=True)
# Compute cumulative probabilities
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs > top_p
# Find the threshold indices
threshold_indices = mask.long().argmax(dim=-1)
threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool()
mask = mask & ~threshold_mask
sorted_indices = torch.where(mask, -1, sorted_indices)
sorted_probs = torch.where(mask, 0.0, sorted_probs)
return sorted_probs, sorted_indices
class SwitchMLP(nn.Module):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, config, layer_idx):
super(SwitchMLP, self).__init__()
self.layer_num = layer_idx
self.use_switch = (layer_idx % config.expert_frequency) == 0 # Ensure the first layer use switch mlp
if self.use_switch:
self.top_p_threshold = config.top_p_threshold
self.router = torch.nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = torch.nn.ModuleList()
self.num_experts = config.num_experts
for i in range(config.num_experts):
self.experts.append(LlamaMLP(config.hidden_size, config.intermediate_size, config.hidden_act))
else:
self.mlp = LlamaMLP(config.hidden_size, config.intermediate_size, config.hidden_act)
def forward(self, hidden_states):
if not self.use_switch:
output = self.mlp(hidden_states)
return output
s = hidden_states.size(0)
b = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
topk_weights, topk_ind = top_p_sampling_batched_all_sequence(route, self.top_p_threshold)
hidden_states = hidden_states.view(-1, hidden_states.size(2))
topk_weights = topk_weights.view(-1, topk_weights.size(2))
topk_ind = topk_ind.view(-1, topk_ind.size(2))
output_total = torch.zeros_like(hidden_states).to(hidden_states)
for expert_num, expert in enumerate(self.experts):
sample_ind, expert_ind = torch.where(topk_ind == expert_num)
hidden = hidden_states[sample_ind.unsqueeze(1), :]
expert_output = expert(hidden)
output_total[sample_ind] += torch.mul(expert_output.squeeze(1), topk_weights[sample_ind,expert_ind].unsqueeze(1))
output_total = output_total.view(s, b, h)
return output_total
代码理解
top_p_sampling_batched_all_sequence()方法
1. 方法功能
- 输入:
- logits: 形状为 [S, B, K],表示序列长度 × 批次大小 × 专家数量
- top_p:浮点数,Top-p 采样的累积概率阈值(如 0.9 表示只保留使累计概率达到 90% 的 token/专家) temperature:
- 温度参数,控制 softmax 分布的“锐利程度”,越高越随机,越低越确定
- 输出:
- sorted_probs: 每个 token中满足 Top-p 条件的专家概率(其余置为 0)
- sorted_indices:对应的专家索引(不满足条件的置为 -1)
2. 逐行分析
- 对 logits 应用温度缩放。
- 如果 temperature > 1:分布更平坦,采样更随机。
- 如果 temperature < 1:分布更尖锐,偏向高概率项。
logits = logits / temperature #[S, B, K]
- 对 logits 降序排序,得到最大到最小的 logit 值及其原始索引。
sorted_probs, sorted_indices = torch.sort(logits, descending=True) #[S, B, K], [S, B, K]
- 计算排序后的累计概率,举例:
- 原始概率 [0.5, 0.3, 0.2]
- 累计概率 [0.5, 0.8, 1.0]
cumulative_probs = torch.cumsum(sorted_probs, dim=-1) #[S, B, K]
- 构建一个布尔掩码,标记哪些位置的累计概率已经超过了 top_p 阈值,则超过 top_p 的位置会被标记为 True。
mask = cumulative_probs > top_p #[S, B, K]
- 找到第一个超过 top_p 的索引(即截止点)。
- argmax 会返回第一个 True 的位置。
- 例如 [False, False, True, True] → argmax = 2
threshold_indices = mask.long().argmax(dim=-1) #[S, B, 1]
- 将threshold_indices转换为one-hot,防止前面所有都被 mask 掉的情况(即至少保留一个专家)
threshold_mask = torch.nn.functional.one_hot(threshold_indices, num_classes=sorted_indices.size(-1)).bool() #[S, B, K]
- 将原本 mask 中超过 top_p 的位置设为 True,但把第一个超过的位置设为 False(即保留该位置)。通俗点说,所有超过 top_p 的都不要了,除了第一个超过的要保留下来
mask = mask & ~threshold_mask #[S, B, K]
- 将超出 Top-p 的专家索引设为 -1,对应概率设为 0.0
sorted_indices = torch.where(mask, -1, sorted_indices) #[S, B, K]
sorted_probs = torch.where(mask, 0.0, sorted_probs)#[S, B, K]
Class SwitchMLP
1. 模块功能
这个模块是 LLaMA 或类似模型中原始 MLP 的替代版本,其核心思想是:
- 在某些层使用 Switch MoE 架构:每个 token 只路由给一个或多个专家处理。
- 路由方式为 Top-p 采样(而不是传统的Top-k),即选择累计概率达到一定阈值的专家。
- 其余层则使用标准的全连接 MLP
2. 初始化函数
- layer_idx: 当前层的编号(从 0 开始)
- config.expert_frequency: 表示每隔多少层使用一次 Switch MoE(比如每 2 层用一次)
- use_switch: 控制当前层是否启用 MoE 模块
def __init__(self, config, layer_idx):
super(SwitchMLP, self).__init__()
self.layer_num = layer_idx
self.use_switch = (layer_idx % config.expert_frequency) == 0
如果启用 Switch MoE:
- top_p_threshold: Top-p 采样的阈值(如 0.9)
- router: 一个线性层,将输入映射到每个专家的 logits
- experts: 多个专家模块,这里使用的是 LlamaMLP(与原始 LLaMA 相同结构的 FFN)
否则使用普通 MLP。
if self.use_switch:
self.top_p_threshold = config.top_p_threshold
self.router = torch.nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = torch.nn.ModuleList()
self.num_experts = config.num_experts
for i in range(config.num_experts):
self.experts.append(LlamaMLP(config.hidden_size, config.intermediate_size, config.hidden_act))
else:
self.mlp = LlamaMLP(config.hidden_size, config.intermediate_size, config.hidden_act)
3. 前向传播函数 forward()
如果不使用 MoE,则直接调用普通 MLP 输出结果。
if not self.use_switch:
output = self.mlp(hidden_states)
return output
如果使用 MoE (关键逻辑)
- 获取输入形状并使用Linear层计算路由权重
- hidden_states 形状为 [S, B, H]:序列长度 × 批次大小 × 隐藏维度
- route 是通过 router 得到的专家 logit 分布,然后经过 softmax 转换为概率分布
- route.shape = [S, B, K],其中 K 是专家数量
s = hidden_states.size(0)
b = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
- 使用 Top-p 采样获取专家权重和索引
- 输入:专家概率分布 route 和 Top-p 阈值
- 输出:
- topk_weights: 每个 token 选中的专家们的权重(概率)
- topk_ind: 对应的专家们的索引
- 输出形状:[S, B, K],K 是每个 token 选择的专家数(可能不同)
topk_weights, topk_ind = top_p_sampling_batched_all_sequence(route, self.top_p_threshold)
topk_weights, topk_ind其实是sorted_probs, sorted_indices
维度都是#[S, B, K]
- 展平张量以便后续处理
- hidden_states: [S, B, H] → [S*B, H]
- topk_weights, topk_ind:[S, B, K] → [S*B, K]
hidden_states = hidden_states.view(-1, hidden_states.size(2))
topk_weights = topk_weights.view(-1, topk_weights.size(2))
topk_ind = topk_ind.view(-1, topk_ind.size(2))
- 遍历每个专家,计算输出并加权累加
- 创建一个全零张量用于保存所有 token 经过专家处理后的加权输出
- 遍历每一个专家(expert_num 表示当前专家编号)。
- 使用 torch.where(topk_ind == expert_num) 获取哪些 token 选择了当前专家。
- sample_ind: 选中当前专家的所有 token 的索引。一维张量 [N],其中 N 是选中当前专家的 token 数量。
- expert_ind: 在 topk_ind 中,这些 token 选择的是第几个专家。与 sample_ind 相同,是一维张量 [N]。由于 topk_ind 的形状是 [S*B, K],所以 expert_ind 的值范围是 [0, K-1]。
- 从 hidden_states 中提取对应的输入到hidden中。
- sample_ind 是一个一维张量,包含选中当前专家的所有 token 的索引。为了正确地从 hidden_states 中提取这些 token 的隐藏状态,我们需要使用 .unsqueeze(1) 将其扩展为二维索引。
- 这里的 hidden 形状是 [N, H],其中 N 是选中当前专家的 token 数量,H 是隐藏层的维度。
- 每个专家对分配给它的 token 进行处理,生成输出 expert_output,形状为 [N, H]
- 将专家的输出加权后累加到最终输出中,只更新那些选中当前专家的 token
- topk_weights 的原始形状是 [SB, K],其中 SB 是展平后的 token 总数,K 是专家总数量。
- sample_ind 和 expert_ind 都是一维张量 [N],表示选中当前专家的所有 token 的索引及其在 topk_ind 中的位置。
- 因此,topk_weights[sample_ind, expert_ind] 的形状是 [N],即每个选中当前专家的 token 对应的权重。
- 因为每个专家可以被多个Token选择,因此这里在专家维度进行循环,在Token维度进行累积
output_total = torch.zeros_like(hidden_states).to(hidden_states)
for expert_num, expert in enumerate(self.experts):
sample_ind, expert_ind = torch.where(topk_ind == expert_num)
hidden = hidden_states[sample_ind.unsqueeze(1), :]
expert_output = expert(hidden)
output_total[sample_ind] += torch.mul(expert_output.squeeze(1), topk_weights[sample_ind,expert_ind].unsqueeze(1))
- 最后,恢复原始形状并返回
output_total = output_total.view(s, b, h)
return output_total