Qwen3中的MoE是如何平衡专家负载的?

近年来,混合专家(Mixture of Experts, MoE)架构因其在扩展模型容量的同时保持计算效率的潜力,在大型语言模型领域受到了广泛关注。Qwen3系列模型也采用了MoE架构,通过稀疏激活特定的“专家”网络来处理不同的输入。然而,MoE模型的一个核心挑战在于如何确保各个专家之间的负载均衡,避免某些专家过载而另一些专家空闲。本文将基于Qwen3的开源代码,深入分析其负载均衡损失函数(load_balancing_loss_func)的设计与实现。

在这里插入图片描述

MoE与负载均衡的重要性

在MoE模型中,一个门控网络(Gating Network)决定将每个输入token路由到哪些专家进行处理。理想情况下,我们希望所有专家都能得到充分利用,并且每个专家都能学到独特的知识。如果路由机制出现偏差,导致大部分token被路由到少数几个专家,就会出现以下问题:

  • 1.专家过载与资源浪费:少数专家计算压力过大,而其他专家则处于空闲状态,导致计算资源利用不均。
  • 2.训练不稳定:不均衡的负载可能导致训练过程不稳定,模型难以收敛。
  • 3.模型性能下降:专家未能充分特化,模型整体性能可能受损。

在这里插入图片描述

因此,引入一个辅助的负载均衡损失函数至关重要,它能够惩罚不均衡的路由行为,鼓励token在专家间均匀分布。

Qwen3中的负载均衡机制借鉴了Switch Transformer论文[1]中的公式(4)至(6):
在这里插入图片描述

辅助损失函数的目标是使每个专家的token分配比例(tokens_per_expert)和路由概率(router_prob_per_expert)尽可能均匀。

让我们一起剖析load_balancing_loss_func函数的具体实现:


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts

该函数接收门控网络输出的gate_logits、专家总数num_experts、每个token选择的专家数量top_k,以及可选的attention_mask(用于处理padding)。
具体的步骤为:

1.门控输出(Gate Logits)

if gate_logits is None or not isinstance(gate_logits, tuple):
return 0

if isinstance(gate_logits, tuple):
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

模型中可能有多层MoE模块,gate_logits是一个元组,包含了每一层MoE的门控输出。这里首先将所有层的gate_logits在第0维度(通常是token维度)上拼接起来。这意味着负载均衡损失是跨所有MoE层、所有token统一计算的。

2.计算路由权重和选择专家

routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

使用softmax将gate_logits转换为路由权重(routing_weights),表示每个token分配到各专家的概率。
再根据routing_weights,为每个token选出概率最高的top_k个专家。最后对selected_experts进行one-hot编码,生成一个掩码,标记了哪些专家被选中。

3.计算核心指标

这是负载均衡损失计算的核心。

# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
    expert_attention_mask, dim=0
)

# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
router_per_expert_attention_mask = (
    attention_mask[None, :, :, None]
    .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
    .reshape(-1, num_experts)
    .to(compute_device)
)

tokens_per_expert反映了每个专家在所有top_k选择中被选中的“密度”或“比例”。如果一个专家被频繁选中(即使不是首选),其对应的值会较高。在Qwen3的实现中,如果attention_mask存在,会先根据attention_mask过滤掉padding token的贡献。
router_prob_per_expert表示门控网络分配给每个专家的平均概率值。同样,如果attention_mask存在,会排除padding token的影响。

4.计算最终损失

overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts

根据公式 Loss = N ⋅ ∑ f i ⋅ P i \text{Loss} = N \cdot \sum f_i \cdot P_i Loss=NfiPi,计算tokens_per_expert和 router_prob_per_expert的点积,并乘以专家数量num_experts,得到最终的辅助损失。
这个计算方式旨在同时鼓励:

  • 1.专家被选中的频率(tokens_per_expert)应该均衡。
  • 2.门控网络对所选专家的置信度(router_prob_per_expert)应该高。

通过将这两者相乘,如果一个专家被频繁选中但门控概率很低,或者门控概率很高但很少被选中,损失都会相应调整。
该损失值越大,说明专家路由越不均衡,模型会通过优化减少这种不均衡。

总结:

Qwen3通过实现Switch Transformer思想的负载均衡损失函数,有效地解决了MoE架构中的专家负载不均问题。该函数通过统计每个专家接收到的token比例以及门控网络分配给各专家的平均概率,构建了一个惩罚项。
这个惩罚项被加到模型的总损失中,在训练过程中引导门控网络学习更均衡的路由策略。这不仅保证了计算资源的有效利用,也促进了各个专家的特化学习,最终有助于提升模型的整体性能和训练稳定性。理解这一机制对于深入掌握和应用MoE大模型至关重要。

[1] Fedus W, Zoph B, Shazeer N. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity[J]. Journal of Machine Learning Research, 2022, 23(120): 1-39.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

G.E.N.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值