2025年2月18日,繁忙的一天,Grok3发布的同时DeepSeek和月之暗面同时发布了他们各自最新的论文,而主题直接“撞车”——都是挑战Transformer架构最核心的注意力机制,让它能更高效的处理更长的上下文。而更有趣的是,两家公司的技术派明星创始人的名字出现在各自的论文和技术报告里。
前面拜读了DeepSeek的2502.11089/Native Sparse Attention原生稀疏注意力在长上下文处理和复杂推理任务中展现了显著优势,现在也来拜读一下月之暗面kimi的这篇论文,其代码以开源https://github.com/MoonshotAI/MoBA
核心贡献与创新点
总体框架:
这篇论文提出了混合块注意力(MoBA)来解决 LLMs 在处理长上下文时的效率问题:基于 MoE 原理,将上下文划分为块,利用门控机制让每个查询令牌动态选择历史相关的键值块进行注意力计算。通过计算查询与块的亲和度分数,使用 top-k 门控选择最相关的块,同时通过特定设计保持因果性,并且支持细粒度块分割和与全注意力的混合。
算法设计
- 块划分和选择策略:
首先,将完整上下文划分为多个块,每个块表示后续令牌的一个子集。然后,采用 MoE 的 top-k 门控机制,使每个查询令牌能够选择性地关注不同块的令牌,而不是整个上下文。公式如下:
M
o
B
A
(
q
,
K
,
V
)
=
Softmax
(
q
K
[
I
]
⊤
)
V
[
I
]
\mathrm{MoBA}(q, K, V)=\operatorname{Softmax}\left(q K[I]^{\top}\right) V[I]
MoBA(q,K,V)=Softmax(qK[I]⊤)V[I]
其中,
I
⊆
[
N
]
I \subseteq[N]
I⊆[N] 是选定的键和值的集合,
I
i
I_{i}
Ii 表示第 i 个块的索引范围。
- 门控机制:MoBA 的门控机制首先计算查询令牌与每个块的亲和度分数 s i s_{i} si,然后应用 top -k 门控在所有块中选择最相关的块。公式如下:
g
i
=
{
1
s
i
∈
Topk
(
{
s
j
∣
j
∈
[
n
]
}
,
k
)
0
otherwise
g_{i}=\begin{cases} 1 & s_{i} \in \operatorname{Topk}(\left\{s_{j} \mid j \in[n]\right\}, k) \\ 0 & \text { otherwise } \end{cases}
gi={10si∈Topk({sj∣j∈[n]},k) otherwise
其中,
Topk
(
⋅
,
k
)
\operatorname{Topk}(\cdot, k)
Topk(⋅,k) 表示从每个块计算的亲和度分数中选择前 k 个最高分数。
- 因果性保证:不注意未来的区块,为了保持自回归语言模型的因果性,MoBA 确保查询令牌不能路由到任何未来的块,并在当前块注意力中应用因果掩码。
MoBA的具体实现步骤
通过融合 FlashAttention 和 MoE 的优化技术,实现高效计算,具体步骤包括以下几个关键部分:
-
数据准备:
- 输入的查询(query)、键(key)和值(value)矩阵被表示为 Q , K , V ∈ R N × h × d Q, K, V \in \mathbb{R}^{N \times h \times d} Q,K,V∈RN×h×d,其中 N N N 是序列长度, h h h 是注意力头的数量, d d d 是每个头的维度。
-
块分割:
- 将键和值矩阵分割成多个块。假设上下文长度 N N N 被分成 n n n 个块,每个块包含 B = N n B = \frac{N}{n} B=nN 个位置。每个块表示后续的一组标记。
-
计算门控分数:
- 计算查询与每个块的相关性分数。这通常是通过计算查询与每个块的平均池化后的键的内积来实现的:
K ˉ = mean_pool ( K , B ) \bar{K} = \text{mean\_pool}(K, B) Kˉ=mean_pool(K,B)
S = Q K ˉ ⊤ S = Q \bar{K}^{\top} S=QKˉ⊤
- 计算查询与每个块的相关性分数。这通常是通过计算查询与每个块的平均池化后的键的内积来实现的:
-
选择相关块:
- 使用一个门控机制(例如 top-k 选择)来确定哪些块对查询最相关。门控值
g
i
g_i
gi 用于指示是否选择第
i
i
i 个块:
g i = { 1 if s i ∈ Topk ( { s j ∣ j ∈ [ n ] } , k ) 0 otherwise g_i = \begin{cases} 1 & \text{if } s_i \in \text{Topk}(\{s_j | j \in [n]\}, k) \\ 0 & \text{otherwise} \end{cases} gi={10if si∈Topk({sj∣j∈[n]},k)otherwise
- 使用一个门控机制(例如 top-k 选择)来确定哪些块对查询最相关。门控值
g
i
g_i
gi 用于指示是否选择第
i
i
i 个块:
-
组织注意力模式:
- 根据门控机制的结果,重新排列查询和键值块,以便计算注意力输出。当前块的注意力也需要应用因果掩码以确保自回归性。
-
计算注意力输出:
- 分别计算当前块和其他选定块的注意力输出。可以使用优化过的注意力算法(如 FlashAttention)来提高计算效率。
-
组合结果:
- 将所有注意力输出组合起来,使用在线 Softmax 进行最终的注意力计算。
实验设计
- 数据收集:实验使用了多个广泛使用的长上下文基准数据集,包括 AGAPEval、BBH、CEval、GSM8K、HellaSWAG、Loogle、Competition Math、MBPP、MBPP Santitized、MMLU、MMLU Pro、OpenAI HumanEval、SimpleQA、TriviaQA、LongBench@32K 和 RULER @128K。
- 实验设置:实验从 Llama -3.18B 基础模型开始,逐步增加上下文长度进行持续预训练。MoBA 的块大小设置为 4096,top-K 参数设置为 12,导致高达 95.31% 的注意力稀疏性。在监督微调阶段,采用层间混合策略,最后三层保持全注意力,其余层切换到 MoBA。
- 参数配置:实验中使用了不同大小的模型,包括 568M、822M、1.1B、1.5B 和 2.1B 参数模型,上下文长度从 8K 到 1M 不等。
结果与分析
- 可扩展性:
MoBA 在验证损失上的缩放趋势与全注意力模型非常相似,表明 MoBA 在达到训练最优时具有可比的可扩展性。 - 长上下文可扩展性:
尽管 MoBA 在所有实验中的最后一个块的 LM 损失略高于全注意力模型,但损失差距逐渐缩小,表明 MoBA 具有良好的长上下文可扩展性。 - 混合训练效果:
MoBA/全注意力混合训练模型在位置级 LM 损失上与全注意力模型几乎相同,展示了 MoBA/全注意力混合训练的灵活性和鲁棒性。【将最后几个Transformer层从MoBA切换到全注意力机制,而其余层继续使用MoBA】 - 实际任务性能:
在各种实际任务中,MoBA 的性能与全注意力模型高度可比,特别是在最长的 RULER 基准上,MoBA 的表现接近全注意力模型。 - 效率:MoBA 在所有上下文长度下的前向传播时间均优于全注意力模型,特别是在填充 1M 令牌时,MoBA 实现了高达 6.5 倍的加速比。【MoBA的高效率可以归因于两个关键创新:(1)块稀疏注意力机制;(2)结合了专家混合(MoE)和FlashAttention的优化实现。这些技术有效地解决了全注意力的二次复杂度限制,将计算复杂度降低到了一个更为经济的次二次规模。】
总体结论
这篇论文提出了混合块注意力(MoBA),一种基于 MoE 原理的新型注意力架构,旨在提高 LLMs 在长上下文任务中的效率和可扩展性。MoBA 通过将上下文划分为块并采用动态门控机制,显著降低了计算复杂性,同时保持了模型性能。实验结果表明,MoBA 在保持低 LM 损失的同时,显著提高了计算效率,并在各种基准测试中表现出色。MoBA 的灵活性使其能够与现有模型无缝集成,是一种实用的持续预训练解决方案,增强了 LLMs 的长上下文能力。未来工作可能包括进一步优化 MoBA 的块选择策略,探索其在其他模态中的应用,以及研究其在复杂推理任务中提高泛化的潜力。
优点与创新
- 创新架构:提出了混合块注意力(MoBA)架构,将 MoE 的原理应用于注意力机制,允许模型自主决定注意力位置,而不是依赖预定义的偏差。
- 块稀疏注意力:通过将上下文划分为块并使用门控机制选择最相关的块,显著降低了计算成本,为处理长序列提供了更高效的解决方案。
- 无缝切换:MoBA 能够在全注意力和稀疏注意力之间无缝切换,增强了模型的兼容性和效率,同时不牺牲性能。
- 实验验证:通过广泛的实验,证明了 MoBA 在长上下文任务中的优越性能,并且在计算效率上显著提高。
- 灵活性:MoBA 的设计允许其与现有模型集成,无需大量训练成本,是一种实用的持续预训练解决方案。
- 扩展性:MoBA 在处理长达 1M 的上下文时表现出色,展示了其在长上下文任务中的扩展能力。
关键问题及回答
========
问题 1:MoBA 的门控机制是如何设计的?其具体实现步骤是什么?
MoBA 的门控机制通过计算查询令牌与每个块的亲和度分数来选择性地关注不同块的子集。具体实现步骤如下:
- 块划分:将完整上下文划分为多个块,每个块代表后续的一部分令牌。假设上下文长度 N 可以被块数 n 整除,块大小为 B = N n B =\frac{N}{n} B=nN。
- 亲和度计算:对于每个查询令牌 q q q,计算其与第 i 块的亲和度分数 s i s_{i} si,公式如下:
s
i
=
⟨
q
,
mean_pool
(
K
[
I
i
]
)
⟩
s_{i}=\langle q,\operatorname{mean\_pool}(K[I_{i}])\rangle
si=⟨q,mean_pool(K[Ii])⟩
其中,
mean_pool
(
K
[
I
i
]
)
\operatorname{mean\_pool}(K[I_{i}])
mean_pool(K[Ii]) 表示第 i 块的键向量的平均池化。
3. 门控计算:对所有块应用 top -k 门控,计算查询令牌与每个块的亲和度分数的 top -k 值,公式如下:
g
i
=
{
1
s
i
∈
Topk
(
{
s
j
∣
j
∈
[
n
]
}
,
k
)
0
otherwise
g_{i}=\begin{cases} 1 & s_{i} \in \operatorname{Topk}\left(\{s_{j} \mid j \in [n]\}, k\right) \\ 0 & \text{otherwise} \end{cases}
gi={10si∈Topk({sj∣j∈[n]},k)otherwise
其中,
T
o
p
k
(
⋅
,
k
)
Topk(\cdot, k)
Topk(⋅,k) 表示从所有块的亲和度分数中选择前 k 个最高分数。
4. 块选择:根据门控值
g
i
g_{i}
gi,选择对应的块进行注意力计算。
通过这种门控机制,MoBA 能够动态地选择最相关的块进行注意力计算,从而提高计算效率并保持模型性能。
问题 2:MoBA 在实验中如何验证其长上下文可扩展性?具体的实验设置和结果是什么?
MoBA 通过增加最大序列长度至 32K 来验证其长上下文可扩展性。具体的实验设置和结果如下:
- 数据收集:使用多个广泛使用的长上下文基准数据集,包括 AGAPEval、BBH、CEval、GSM8K、HellaSWAG、Loogle、Competition Math、MBPP、MBPP Santitized、MMLU、MMLU Pro、OpenAI HumanEval、SimpleQA、TriviaQA、LongBench@32K 和 RULER @128K。
- 实验设置:实验开始于 Llama -3.18B 基础模型,逐步增加上下文长度进行持续预训练。MoBA 的块大小设置为 4096,top -K 参数设置为 12,导致高达 95.31% 的注意力稀疏度。在监督微调阶段,采用层间混合策略,最后三层保持全注意力,其余层切换为 MoBA。
- 实验结果:
- 缩放律实验:通过比较使用全注意力和 MoBA 的语言模型在验证集上的损失,发现 MoBA 的损失曲线与全注意力非常相似,损失差异在 1e-3 范围内。这表明 MoBA 在缩放性能上与全注意力相当。
- 长上下文可扩展性:尽管 MoBA 在所有实验中的最后一个块的损失略高于全注意力,但损失差距逐渐缩小,表明 MoBA 具有良好的长上下文可扩展性。
通过这些实验结果,验证了 MoBA 在处理长上下文任务时的有效性和可扩展性。
问题 3:MoBA 在实验中与全注意力模型的性能比较如何?具体的实验结果和结论是什么?
MoBA 在多个实验中与全注意力模型的性能进行了比较,具体的实验结果和结论如下:
- 缩放律实验:在验证集上的损失结果显示,MoBA 的损失曲线与全注意力非常相似,损失差异在 1e-3 范围内。这表明 MoBA 在缩放性能上与全注意力相当。
- 长上下文可扩展性实验:通过增加最大序列长度至 32K,MoBA 的稀疏注意力模式变得更稀疏,达到 95.31% 的稀疏度。尽管 MoBA 在所有实验中的最后一个块的损失略高于全注意力,但损失差距逐渐缩小,表明 MoBA 具有良好的长上下文可扩展性。
- 大规模语言模型评估:在多个长上下文基准任务上,MoBA 的性能与全注意力模型高度可比。特别是在最长的 RULER 基准任务中,MoBA 在稀疏度高达 62.5% 的情况下,性能接近全注意力模型,得分为 0.7818,而全注意力模型的得分为 0.7849。
- MoBA 与全注意力的混合实验:MoBA/全注意力混合训练在长上下文性能上与全注意力模型几乎相同,且在 MoBA 和全注意力之间切换时未观察到显著的损失峰值,展示了 MoBA 的灵活性和鲁棒性。
综上所述,MoBA 在性能上与全注意力相当,但显著提高了计算效率,并且在多个长上下文基准任务中表现出色。MoBA 的灵活性使其能够与现有模型无缝集成,是一种实用的持续预训练解决方案,增强了 LLMs 的长上下文能力。
核心代码解读
这个项目的核心代码moba_attn_varlen_naive
函数进行详细解读。
函数定义及输入参数
def moba_attn_varlen_naive(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
moba_chunk_size: int,
moba_topk: int,
) -> torch.Tensor:
q
、k
、v
:分别是查询(query)、键(key)和值(value)张量,形状为[seqlen, head, head_dim]
。cu_seqlens
:累积序列长度张量,用于表示不同批次的序列长度。max_seqlen
:批次中最大的序列长度。moba_chunk_size
:MoBA块的大小。moba_topk
:选择的前k个块。
主要步骤
1. 初始化参数
batch = cu_seqlens.numel() - 1
softmax_scale = q.shape[-1] ** (-0.5)
o = torch.zeros_like(q)
batch
:计算批次的数量。softmax_scale
:计算softmax缩放因子,用于后续的注意力计算。o
:初始化输出张量,形状与q
相同,初始值为0。
2. 遍历每个批次
for batch_idx in range(batch):
batch_start = cu_seqlens[batch_idx].item()
batch_end = cu_seqlens[batch_idx + 1].item()
# get qkv of this batch
q_ = q[batch_start:batch_end]
k_ = k[batch_start:batch_end]
v_ = v[batch_start:batch_end]
o_ = o[batch_start:batch_end]
- 对于每个批次,根据
cu_seqlens
确定该批次的起始和结束位置,然后提取该批次的q
、k
、v
和o
。
3. 计算键门控权重
key_gate_weight = []
batch_size = batch_end - batch_start
num_block = math.ceil(batch_size / moba_chunk_size)
for block_idx in range(0, num_block):
block_start = block_idx * moba_chunk_size
block_end = min(batch_size, block_start + moba_chunk_size)
key_gate_weight.append(k_[block_start:block_end].mean(dim=0, keepdim=True))
key_gate_weight = torch.cat(key_gate_weight, dim=0) # [ N, H, D ]
- 将当前批次的
k
按moba_chunk_size
分成多个块,计算每个块的均值作为键门控权重,最后将这些权重拼接成一个张量。
4. 计算并掩码门控值
# calc & mask gate
# use fp32 to avoid precision issue in bf16
q_ = q_.type(torch.float32)
key_gate_weight = key_gate_weight.type(torch.float32)
gate = torch.einsum("shd,nhd->hsn", q_, key_gate_weight) # [ H, S, N ]
key_gate_weight = key_gate_weight.type_as(k)
q_ = q_.type_as(k)
for i in range(num_block):
# select the future Qs that can attend to KV chunk i
gate[:, : (i + 1) * moba_chunk_size, i] = float("-inf")
gate[:, i * moba_chunk_size : (i + 1) * moba_chunk_size, i] = float("inf")
- 计算门控值
gate
,通过爱因斯坦求和torch.einsum
实现。 - 对门控值进行掩码操作,确保每个查询只能关注到特定的键值块。
5. 选择前k个块
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=min(moba_topk, num_block), dim=-1, largest=True, sorted=False
)
gate_top_k_val, _ = gate_top_k_val.min(dim=-1) # [ H, S ]
need_attend = gate >= gate_top_k_val.unsqueeze(-1)
- 使用
torch.topk
选择前k个最大的门控值及其索引。 - 确定哪些查询需要关注这些块。
6. 处理门控值
gate_idx_mask = torch.zeros(
need_attend.shape, dtype=torch.bool, device=q.device
)
gate_idx_mask = gate_idx_mask.scatter_(dim=-1, index=gate_top_k_idx, value=True)
need_attend = torch.logical_and(need_attend, gate_idx_mask)
gate[need_attend] = 0
gate[~need_attend] = -float("inf")
gate = gate.repeat_interleave(moba_chunk_size, dim=-1)[:, :, :batch_size] # [ H, S, S ]
gate.masked_fill_(
torch.ones_like(gate, dtype=torch.bool).tril().logical_not(), -float("inf")
)
- 创建一个掩码张量
gate_idx_mask
,用于处理门控值。 - 根据
need_attend
和gate_idx_mask
更新门控值,确保只有前k个块被关注。 - 对门控值进行重复和掩码操作,以适应后续的注意力计算。
7. 计算注意力输出
# calc qk = qk^t
q_ = q_.type(torch.float32)
k_ = k_.type(torch.float32)
v_ = v_.type(torch.float32)
qk = torch.einsum("xhd,yhd->hxy", q_, k_)
# mask
qk += gate
qk *= softmax_scale
# calc o
p = qk.softmax(dim=-1)
o_ += torch.einsum("hxy,yhd->xhd", p, v_)
o = o.type_as(q)
- 计算查询和键的点积
qk
,并添加门控值和缩放因子。 - 使用softmax函数计算注意力权重
p
。 - 计算注意力输出
o_
,并更新输出张量o
。
函数返回值
return o
返回计算得到的注意力输出张量。
完整代码:
def moba_attn_varlen_naive(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
moba_chunk_size: int,
moba_topk: int,
) -> torch.Tensor:
"""Implement the moba brute-force setting for reference
Args:
q (torch.Tensor): [seqlen, head, head_dim],查询张量
k (torch.Tensor): [seqlen, head, head_dim],键张量
v (torch.Tensor): [seqlen, head, head_dim],值张量
cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn,累积序列长度张量,与Flash Attention中的定义相同
max_seqlen (int): the max sequence length of the batch, same definition in flash attn,批次中的最大序列长度,与Flash Attention中的定义相同
moba_chunk_size (int): MoBA块的大小
moba_topk (int): 选择的top-k值
Returns:
attn_output (torch.Tensor): [seqlen, head, head_dim],注意力输出张量
"""
# qkv shape = [ S, H, D ]
# 计算批次数量,累积序列长度张量的元素数量减1即为批次数量
batch = cu_seqlens.numel() - 1
# 计算softmax缩放因子,为查询张量最后一维大小的负0.5次方
softmax_scale = q.shape[-1] ** (-0.5)
# 初始化输出张量o,形状与查询张量q相同,初始值全为0
o = torch.zeros_like(q)
# 遍历每个批次
for batch_idx in range(batch):
# 获取当前批次的起始和结束位置
batch_start = cu_seqlens[batch_idx].item()
batch_end = cu_seqlens[batch_idx + 1].item()
# 从q、k、v、o中提取当前批次的数据
q_ = q[batch_start:batch_end]
k_ = k[batch_start:batch_end]
v_ = v[batch_start:batch_end]
o_ = o[batch_start:batch_end]
# calc key gate weight
# 初始化键门控权重列表
key_gate_weight = []
# 计算当前批次的大小
batch_size = batch_end - batch_start
# 计算当前批次需要的块数量,向上取整
num_block = math.ceil(batch_size / moba_chunk_size)
# 遍历每个块
for block_idx in range(0, num_block):
# 计算当前块的起始和结束位置
block_start = block_idx * moba_chunk_size
block_end = min(batch_size, block_start + moba_chunk_size)
# 计算当前块的键的均值,作为键门控权重的一部分
key_gate_weight.append(k_[block_start:block_end].mean(dim=0, keepdim=True))
# 将键门控权重列表中的张量按维度0拼接成一个张量,形状为 [ N, H, D ]
key_gate_weight = torch.cat(key_gate_weight, dim=0) # [ N, H, D ]
# calc & mask gate
# 使用fp32数据类型以避免bf16中的精度问题
q_ = q_.type(torch.float32)
key_gate_weight = key_gate_weight.type(torch.float32)
# 计算门控张量,使用爱因斯坦求和约定,形状为 [ H, S, N ]
gate = torch.einsum("shd,nhd->hsn", q_, key_gate_weight) # [ H, S, N ]
# 将键门控权重和查询张量恢复为原始数据类型
key_gate_weight = key_gate_weight.type_as(k)
q_ = q_.type_as(k)
# 遍历每个块
for i in range(num_block):
# 选择未来的查询张量,将不能关注到当前KV块i的位置设为负无穷
gate[:, : (i + 1) * moba_chunk_size, i] = float("-inf")
# 将当前KV块i对应的位置设为正无穷
gate[:, i * moba_chunk_size : (i + 1) * moba_chunk_size, i] = float("inf")
# 计算门控张量的top-k值和索引,k取moba_topk和num_block中的较小值
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=min(moba_topk, num_block), dim=-1, largest=True, sorted=False
)
# 取top-k值中的最小值,形状为 [ H, S ]
gate_top_k_val, _ = gate_top_k_val.min(dim=-1) # [ H, S ]
# 判断哪些位置需要关注,即门控值大于等于top-k最小值的位置
need_attend = gate >= gate_top_k_val.unsqueeze(-1)
# add gate_idx_mask in case of there is cornercases of same topk val been selected
# 初始化门控索引掩码,形状与need_attend相同,初始值全为False
gate_idx_mask = torch.zeros(
need_attend.shape, dtype=torch.bool, device=q.device
)
# 根据top-k索引将对应位置设为True
gate_idx_mask = gate_idx_mask.scatter_(dim=-1, index=gate_top_k_idx, value=True)
# 将need_attend和gate_idx_mask进行逻辑与运算,得到最终需要关注的位置
need_attend = torch.logical_and(need_attend, gate_idx_mask)
# 将需要关注的位置设为0,不需要关注的位置设为负无穷
gate[need_attend] = 0
gate[~need_attend] = -float("inf")
# 将门控张量在最后一维重复moba_chunk_size次,并截取前batch_size列,形状为 [ H, S, S ]
gate = gate.repeat_interleave(moba_chunk_size, dim=-1)[
:, :, :batch_size
] # [ H, S, S ]
# 将门控张量的上三角部分(不包括对角线)设为负无穷
gate.masked_fill_(
torch.ones_like(gate, dtype=torch.bool).tril().logical_not(), -float("inf")
)
# calc qk = qk^t
# 将查询、键、值张量转换为float32数据类型
q_ = q_.type(torch.float32)
k_ = k_.type(torch.float32)
v_ = v_.type(torch.float32)
# 计算查询和键的点积,使用爱因斯坦求和约定,形状为 [ H, S, S ]
qk = torch.einsum("xhd,yhd->hxy", q_, k_)
# mask
# 将门控张量加到qk上
qk += gate
# 将qk乘以softmax缩放因子
qk *= softmax_scale
# calc o
# 对qk进行softmax操作,得到注意力概率
p = qk.softmax(dim=-1)
# 根据注意力概率和值张量计算输出,使用爱因斯坦求和约定,并累加到o_上
o_ += torch.einsum("hxy,yhd->xhd", p, v_)
# 将o的类型转换为与q相同的类型
o = o.type_as(q)
return o