昨天简直赶上科技圈的小春晚了,白天阶跃刚发布了业内最先进的视频生成和语音大模型,马斯克刚发布了号称“人类历史上最聪明”的大模型Grok-3,紧接着DeepSeek团队就发布了他们的最新论文,进一步提出了最新的注意力机制,从名字来看又是一个从软件到硬件适配(算子)都做了优化的算法。
论文链接:https://arxiv.org/pdf/2502.11089
从论文的作者列表可以看到DeepSeek的创始人梁文峰赫然在列。
消息一经公开,短短的几个小时内就已经有将近100w的浏览量,也引发了热烈的讨论,DeepSeek目前在业内的影响力可见一斑,感觉已经超过OpenAI了。
那我们也抓紧一探究竟,看看这个最新的注意力机制到底是个什么东西~
老规矩,我们先来看下论文的整体结构。
开门见山,先给出这篇论文的主要贡献,然后再详细展开。
主要贡献
论文提出了原生可训练的稀疏注意力机制(Native Sparse Attention, NSA),将算法创新与硬件优化相结合,实现高效长文本建模,主要贡献体现在方法改进、性能提升、效率优化等方面。
1. 创新稀疏注意力机制设计:提出 NSA,融合动态分层稀疏策略,结合粗粒度token压缩和细粒度token选择,兼顾全局上下文感知与局部精度,改进了传统稀疏注意力设计。通过将键值对重映射,设计了token压缩、token选择和滑动窗口三种策略,构建了完整的算法框架。
2. 实现硬件对齐与训练感知优化:从硬件对齐系统和训练感知设计两方面优化。针对现代硬件优化块稀疏注意力,平衡算术强度,提高硬件利用率;基于triton设计了高效的注意力算子,实现稳定的端到端训练,减少预训练计算量且不牺牲模型性能。
3. 提升模型性能表现:在多个基准测试中,NSA 预训练模型性能与全注意力模型相当甚至超越。在通用基准测试、长上下文任务和基于指令的推理任务中表现出色,尤其在推理相关基准测试中有显著提升,验证了其作为通用架构的稳健性。
4. 显著提高计算效率:在处理 64k 长度序列时,NSA 在解码、前向传播和反向传播阶段均比完整多头注意力机制有大幅加速,且序列越长加速比越高。训练阶段 64k 上下文长度下,前向加速达 9.0 倍,反向加速达 6.0 倍;解码阶段在 64k 上下文长度下,速度提升最高可达 11.6 倍。
用人话来说就是DeepSeek又发布了一种新的模型结构,可以比传统模型更好更快的处理更长的上下文。
摘要
长上下文建模对于下一代语言模型至关重要,然而标准注意力机制的高计算成本带来了巨大的计算挑战。稀疏注意力为在保持模型能力的同时提高效率提供了一个有前景的方向。我们提出了原生可训练稀疏注意力机制(NSA),它将算法创新与硬件适配优化相结合,以实现高效的长上下文建模。NSA采用动态分层稀疏策略,将粗粒度token压缩与细粒度token选择相结合,既保留了全局上下文感知,又保证了局部精度。
我们的方法通过两项关键创新推进了稀疏注意力设计:
(1)通过算术强度平衡的算法设计,并针对现代硬件进行实现优化,我们实现了显著的加速。
(2)我们实现了端到端训练,在不牺牲模型性能的情况下减少了预训练计算量。
如图1所示,实验表明,使用NSA预训练的模型在通用基准测试、长上下文任务和基于指令的推理中,性能与全注意力模型相当甚至超越。同时,在处理64k长度序列时,NSA在解码、前向传播和反向传播方面比全注意力机制实现了大幅加速,验证了其在模型整个生命周期中的高效性。
1. 方法
1.1 背景回顾
正好再回顾下最普通的注意力机制是什么样的,注意力机制在语言建模中应用广泛,每个查询 token 会计算与所有前文 key
的相关度分数,进而生成值
的加权和。对于长度为
的输入序列,注意力操作可以定义如下:
其中Attn()代表注意力计算公式:
其中是q和k的乘积得分,d是特征维度。随着序列长度增加,注意力计算在整体计算成本中占比越来越大,给长上下文处理带来巨大挑战。
算术强度(Arithmetic Intensity )是计算操作数与内存访问数的比值,它本质上影响着硬件上的算法优化。每个 GPU 都有一个由峰值计算能力和内存带宽决定的关键算术强度,通过这两个硬件限制的比值计算得出。对于计算任务而言,算术强度高于该关键阈值时,计算受限于 GPU 的每秒浮点运算次数(FLOPS);低于该阈值时,则受限于内存带宽。
具体到因果自注意力机制,在训练和预填充阶段,批量矩阵乘法和注意力计算展现出较高的算术强度,使得这些阶段在现代加速器上属于计算受限型。相反,自回归解码受内存带宽限制,因为它每次前向传递仅生成一个 token,但却需要加载整个键值缓存,导致算术强度较低。这就产生了不同的优化目标:在训练和预填充阶段降低计算成本,在解码阶段减少内存访问。
1.2 整体框架
为利用具有自然稀疏模式的注意力的潜力,我们提议用一组更紧凑且信息更密集的表示键值对,
、 替换公式 (1) 中的原始键值对
、
,这组新的键值对是针对每个查询
构建的。具体而言,我们正式将优化后的注意力输出定义如下:
其中,、
基于当前查询
和上下文记忆
、
动态构建。我们可以设计多种映射策略来得到不同类别的
、
,并按如下方式组合:
如图 2 所示,NSA 有三种映射策略 ,分别代表键值对的压缩、选择和滑动窗口策略。
是对应策略
的门控分数,通过 MLP 和 sigmoid 激活函数从输入特征中导出。用
表示重新映射后的键 / 值总数:
我们通过确保 来维持较高的稀疏率。
1.3 算法设计
Token Compression (token 压缩)
通过将连续的 key 或 value 块聚合为块级表示,我们得到能捕捉整个块信息的压缩 key 和 value。压缩 key 表示定义如下:
其中, 是块长度,
是相邻块之间的滑动步长,
是一个可学习的 MLP,带有块内位置编码,用于将块中的 key 映射为单个压缩 key。
是由压缩 key 组成的张量。通常,我们采用
来减少信息碎片化。类似的公式适用于压缩 value 表示
。压缩表示捕捉更粗粒度的高级语义信息,降低了注意力的计算负担。
Token Selection (token 选择)
仅使用压缩后的 key 和 value 可能会丢失重要的细粒度信息,因此我们设计了高效的 token 选择机制,以较低的计算开销识别并保留最相关的 token。
1)Blockwise Selection (块级选择)
我们的选择策略以空间连续的块为单位处理 key 和 value 序列,这是出于硬件效率和注意力分数固有分布模式两方面的考虑。块级选择对于在现代 GPU 上实现高效计算至关重要,因为现代 GPU 架构对连续块访问的吞吐量明显高于基于随机索引的读取,并且块级计算能实现对张量核心的最优利用。这种架构特性使得块级内存访问和计算成为高性能注意力实现的基本原则,例如 FlashAttention 的基于块的设计。块级选择也遵循注意力分数的固有分布模式,先前研究表明注意力分数通常呈现空间连续性,即相邻的 key 往往具有相似的重要性级别。
2)Importance Score Computation (重要性分数计算)
计算块重要性分数可能会带来显著的计算开销。幸运的是,压缩 token 的注意力计算会产生中间注意力分数,我们可以利用这些分数来推导选择块的重要性分数,公式如下:
其中, 是
与压缩 key
之间的注意力分数。设
为选择块大小,当压缩块和选择块的分块方案相同时,即
,我们可以直接得到选择块重要性分数
。当分块方案不同时,我们根据它们的空间关系推导选择块的重要性分数。假设
且
,则有:
其中, 表示访问向量元素的索引运算符。对于采用 GQA 或 MQA(其中键值缓存跨查询头共享)的模型,必须确保跨这些头的一致块选择,以最小化解码期间的键值缓存加载。同一组中跨头的共享重要性分数正式定义为:
其中,上标 表示头索引,
是每组中的查询头数量。这种聚合确保了同一组内跨头的一致块选择。
3)Top-n Block Selection (top-n 块选择)
获得选择块重要性分数后,我们保留按块重要性分数排名前 的稀疏块中的 token,公式如下:
其中, 表示降序排名位置,
对应最高分数,
是所选块的索引集,
表示拼接操作。
是由压缩 key 组成的张量。类似的公式适用于细粒度 value
。然后,所选的 key 和 value 按照公式 (5) 与
参与注意力计算。
Sliding Window (滑动窗口)
在注意力机制中,局部模式通常适应更快,可能会主导学习过程,这可能会阻碍模型有效地从压缩和选择的 token 中学习。为解决此问题,我们引入专门的滑动窗口分支来显式处理局部上下文,使其他分支(压缩和选择)能够专注于学习各自的特征,而不会被局部模式干扰。具体来说,我们在大小为 的窗口中保留最近的 token
,
,并将不同信息源(压缩 token、选择 token 和滑动窗口)的注意力计算分离到不同分支。然后,这些分支的输出通过学习的门控机制进行聚合。为进一步防止不同注意力分支之间的干扰,同时尽量减少计算开销,我们为三个分支提供独立的 key 和 value。这种架构设计通过防止局部和长距离模式识别之间的梯度干扰,实现了稳定的学习,且引入的开销极小。
获得所有三类 key 和 value(,
;
,
;以及
,
)后,我们按照公式 (5) 计算最终的注意力输出。结合上述的压缩、选择和滑动窗口机制,构成了 NSA 完整的算法框架。
1.4 Kernel Design (算子kernel设计)
为在训练和预填充阶段实现与 FlashAttention 相当的加速效果,我们基于 Triton 实现了硬件适配的稀疏注意力内核。鉴于多头注意力(MHA)在解码时内存消耗大且效率低,我们遵循当前最先进的大语言模型,专注于具有共享键值缓存的架构,如 GQA 和 MQA。虽然压缩和滑动窗口注意力计算很容易与现有的 FlashAttention - 2 内核兼容,但我们为稀疏选择注意力设计了专门的内核。如果我们遵循 FlashAttention 将时间连续的查询块加载到静态随机存取存储器(SRAM)的策略,由于块内的查询可能需要不连续的键值块,会导致内存访问效率低下。为解决此问题,我们的关键优化在于采用不同的查询分组策略:对于查询序列上的每个位置,我们将 GQA 组内的所有查询头(它们共享相同的稀疏键值块)加载到 SRAM 中。图 3 展示了我们的前向传递实现。所提出的内核架构具有以下关键特征:
a)以组为中心的数据加载:对于每个内循环,加载组中位置 处所有头的查询
及其共享的稀疏键(key) / 值(value)块索引
。
b)共享键值获取:在内循环中,将由 索引的连续键 / 值块顺序加载到 SRAM 中,分别记为 ,
,以最小化内存加载,其中
是满足
的内核块大小。
c)网格外循环:由于不同查询块的内循环长度(与所选块数量 成比例)几乎相同,我们将查询 / 输出循环放入 Triton 的网格调度器中,以简化和优化内核。
这种设计通过以下方式实现了接近最优的算术强度:(1)通过组间共享消除冗余的键值传输;(2)平衡 GPU 流式多处理器之间的计算负载。
2. 实验
NSA在LongBench上与当前主流的稀疏注意力模型以及完整的注意力机制做了对比,下表可以看到,在大部分情况下,表现都是最好的。
在通用基准测试中,对比 NSA 和全注意力模型在知识、推理和编码能力相关基准测试中的表现,NSA 在多数指标上优于全注意力模型;在长上下文评估中,通过 64k 上下文的 “大海捞针” 测试和 LongBench 基准测试,NSA 在长上下文任务中表现出色;在思维链推理评估中,通过知识蒸馏和监督微调,在 AIME 基准测试中,NSA - R 的准确性高于全注意力 - R
3. 效率分析
在配备 8 个 A100 GPU 的系统上,将 NSA 的计算效率与全注意力机制进行对比评估。在效率分析中,同样将模型配置为 GQA 分组数(g = 4),每组头数(h = 16),查询(query) / 键(key)维度 (),值(value)维度(
)。按照第 4 节中的相同设置,设定 NSA 的压缩块大小(l = 32),滑动步长(d = 16),选定块大小(l' = 64),选定块数量(n = 16)(包括固定激活的 1 个初始块和 2 个局部块),滑动窗口大小(w = 512)。
3.1 训练速度
将基于Triton实现的NSA注意力机制和基于Triton的 FlashAttention-2 实现的MHA进行比较,以确保在相同后端下进行公平的速度对比。如图 6 所示,随着上下文长度的增加,NSA 的加速比逐渐增大,在上下文长度为 64k 时,前向传播速度提升可达 9.0 倍,反向传播速度提升可达 6.0 倍。值得注意的是,序列越长,速度优势越明显。这种加速得益于我们为最大化稀疏注意力架构效率而进行的硬件对齐算法设计:(1)按块的内存访问模式通过合并加载最大化了张量核心的利用率;(2)内核中精细的循环调度消除了冗余的键值(KV)传输。
3.2 解码速度
注意力机制的解码速度主要由内存访问瓶颈决定,这与键值缓存(KV cache)的加载量密切相关。在每个解码步骤中,NSA 最多只需加载个压缩 token、nl'个选定 token 和w个相邻 token,其中s是缓存的序列长度。如表 4 所示,随着解码长度的增加,NSA的方法延迟显著降低,在上下文长度为 64k 时,速度提升最高可达 11.6 倍。这种内存访问效率的优势也随着序列变长而更加显著。
再次感叹,现在的技术发展的实在太快了,自从年前DeepSeek-R1发布之后,彷佛整个AI的研究都被按下了快进键,只有加大学习力度才不会被淘汰。