Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
摘要(Abstract)
长上下文建模(Long-context modeling)对于下一代语言模型至关重要,但标准注意力机制(Attention Mechanism)的高计算成本带来了显著的挑战。稀疏注意力(Sparse Attention)为提升效率同时保持模型能力提供了方向。本研究提出了NSA(Natively trainable Sparse Attention),一种可原生训练的稀疏注意力机制,通过算法创新与硬件优化实现高效长上下文建模。NSA采用动态分层稀疏策略,结合粗粒度标记压缩(Token Compression)和细粒度标记选择(Token Selection),保留全局上下文感知和局部精度。主要创新包括:(1) 通过算术强度平衡的算法设计和硬件优化实现显著加速;(2) 支持端到端训练,降低预训练计算成本而不牺牲性能。实验表明,NSA在通用基准、长上下文任务和指令推理中匹配或超越全注意力(Full Attention)模型,同时在64k序列上显著提升解码、前向和反向传播速度。
1. 引言(Introduction)
1.1 长上下文建模的重要性
研究界日益认识到,长上下文建模是下一代大语言模型(Large Language Models, LLMs)的核心能力。这一需求源于多样化的现实应用,例如深入推理(DeepSeek-AI, 2025; Zelikman et al., 2022)、仓库级代码生成(Zhang et al., 2023a)和多轮自主代理系统(Park et al., 2023)。近期突破,如OpenAI的o系列模型、DeepSeek-R1和Gemini 1.5 Pro,展示了处理长文档和复杂推理的能力。然而,传统注意力机制的高复杂度(Zaheer et al., 2020)成为瓶颈,尤其在64k长度上下文解码时,注意力计算占总延迟的70-80%。
1.2 稀疏注意力的潜力与挑战
利用softmax注意力的固有稀疏性是一种自然解决方案(Ge et al., 2023)。现有方法通过KV缓存驱逐(KV-cache Eviction)、块状选择(Blockwise Selection)和聚类方法减少计算开销。然而,这些方法在实际部署中常未达预期加速,且缺乏训练支持,限制了稀疏性优势的充分发挥。
1.3 本文贡献
为解决上述问题,本研究提出了NSA,一种结合硬件对齐和可训练性的稀疏注意力架构。NSA通过分层标记建模降低计算负担,并在全生命周期(训练、推理)中验证了高效性。
图 1| Full Attention 模型与我们的 NSA 模型在性能和效率上的对比。
左图:尽管 NSA 是稀疏的,但它在通用基准测试、长文本任务和推理评估中平均表现超越了 Full Attention 基线模型。
右图:在处理 64k 长度的序列时,NSA 在解码、前向传播和反向传播等所有阶段都实现了显著的计算速度提升,相较于 Full Attention 模型。
图 2 | NSA 架构概览。
左图:该框架通过三个并行的注意力分支处理输入序列:对于给定的查询,前序的键和值被分别处理为压缩注意力(用于粗粒度模式)、选择注意力(用于重要的 token 块)和滑动注意力(用于局部上下文)。
右图:每个分支生成的不同注意力模式的可视化。绿色区域表示需要计算注意力得分的区域,而白色区域表示可以跳过的区域。
2. 重新思考稀疏注意力方法(Rethinking Sparse Attention Methods)
2.1 高效推理的假象(The Illusion of Efficient Inference)
许多稀疏注意力方法虽在理论上减少计算,但在推理延迟上未见显著改善。原因包括:(1) 阶段受限稀疏性,如H2O在解码中应用稀疏但预填充阶段仍需高计算;(2) 与高级架构(如GQA)的内存访问不兼容,导致实际加速受限。
2.2 可训练稀疏性的神话(The Myth of Trainable Sparsity)
现有方法多在推理阶段应用稀疏,忽视训练需求,导致性能下降和训练效率低下。非可训练组件(如ClusterKV的k-means聚类)和低效反向传播进一步加剧了问题。
2.3 本土稀疏性的必要性(Native Sparsity as an Imperative)
上述限制促使我们重新设计NSA,兼顾推理效率和训练可行性。
3. 方法(Methodology)
3.1 背景(Background)
注意力机制在语言建模中通过查询 q t \mathbf{q}_t qt与键 k : t \mathbf{k}_{:t} k:t计算相关性,生成值的加权和 o t \mathbf{o}_t ot:
o t = Attn ( q t , k : t , v : t ) = ∑ i = 1 t α t , i v i ∑ j = 1 t α t , j , α t , i = e q t ⊤ k i d k \mathbf{o}_t = \operatorname{Attn}(\mathbf{q}_t, \mathbf{k}_{:t}, \mathbf{v}_{:t}) = \sum_{i=1}^t \frac{\alpha_{t,i} \mathbf{v}_i}{\sum_{j=1}^t \alpha_{t,j}}, \quad \alpha_{t,i} = e^{\frac{\mathbf{q}_t^\top \mathbf{k}_i}{\sqrt{d_k}}} ot=Attn(qt,k:t,v:t)=i=1∑t∑j=1tαt,jαt,ivi,αt,i=edkqt⊤ki
其中 d k d_k dk为键的特征维度。算术强度(Arithmetic Intensity)影响硬件优化,训练阶段受计算限制,解码阶段受内存带宽限制。
3.2 总体框架(Overall Framework)
NSA将原始键值对替换为更紧凑的表示 κ ^ t , v ^ t \hat{\kappa}_t, \hat{v}_t κ^t,v^t,优化注意力输出:
o t ∗ = ∑ c ∈ { c m p , s l c , w i n } g t c ⋅ Attn ( q t , κ ~ t c , V ~ t c ) \mathbf{o}_t^* = \sum_{c \in \{\mathrm{cmp}, \mathrm{slc}, \mathrm{win}\}} g_t^c \cdot \operatorname{Attn}(\mathbf{q}_t, \tilde{\kappa}_t^c, \tilde{V}_t^c) ot∗=c∈{cmp,slc,win}∑gtc⋅Attn(qt,κ~tc,V~tc)
包括压缩、选择和滑动窗口三种策略,确保稀疏率 N t ≪ t N_t \ll t Nt≪t。
3.3 算法设计(Algorithm Design)
3.3.1 标记压缩(Token Compression)
通过可学习的MLP将键序列块聚合为压缩表示:
κ ^ t c m p = { φ ( k i d + 1 : i d + l ) ∣ 0 ≤ i ≤ ⌊ t − l d ⌋ } \hat{\kappa}_t^{\mathrm{cmp}} = \left\{\varphi(\mathbf{k}_{id+1:id+l}) \mid 0 \leq i \leq \left\lfloor \frac{t-l}{d} \right\rfloor \right\} κ^tcmp={φ(kid+1:id+l)∣0≤i≤⌊dt−l⌋}
其中 l l l为块长度, d d d为滑动步幅。
3.3.2 标记选择(Token Selection)
采用块状选择,利用压缩标记的注意力分数计算重要性:
p t c m p = Softmax ( q t ⊤ K ^ t c m p ) \mathbf{p}_t^{\mathrm{cmp}} = \operatorname{Softmax}(\mathbf{q}_t^\top \hat{K}_t^{\mathrm{cmp}}) ptcmp=Softmax(qt⊤K^tcmp)
保留前 n n n个块的键值对。
3.3.3 滑动窗口(Sliding Window)
维护局部上下文 K ^ t w i n = k t − w : t \hat{K}_t^{\mathrm{win}} = \mathbf{k}_{t-w:t} K^twin=kt−w:t,通过门控机制聚合各分支输出。
3.4 内核设计(Kernel Design)
在Triton上实现硬件优化内核,针对GQA分组加载查询和KV块,显著提升训练和预填充速度。
图 3 | NSA 的核设计
该核通过 GQA 组(Grid Loop)加载查询,获取对应的稀疏 KV 块(Inner Loop),并在 SRAM 上执行注意力计算。绿色块表示 SRAM 上的数据,蓝色块表示 HBM 上的数据
4. 实验(Experiments)
4.1 预训练设置(Pretraining Setup)
实验采用27B参数模型,结合GQA和MoE架构,在270B个8k长度文本上预训练,随后在32k文本上微调。
4.2 性能比较(Performance Comparison)
表1. 通用基准性能
模型 | MMLU | MMLU-PRO | CMMLU | BBH | GSM8K | MATH | DROP | MBPP | HumanEval | 平均 |
---|---|---|---|---|---|---|---|---|---|---|
全注意力 | 0.567 | 0.279 | 0.576 | 0.497 | 0.486 | 0.263 | 0.503 | 0.482 | 0.335 | 0.443 |
NSA | 0.563 | 0.286 | 0.587 | 0.521 | 0.520 | 0.264 | 0.545 | 0.466 | 0.348 | 0.456 |
NSA在大多数任务中优于全注意力,尤其在推理任务中提升显著。
长上下文评估:NSA在64k needle-in-a-haystack测试中实现完美检索,在LongBench上平均得分0.469,优于基线。
思维链推理:NSA-R在AIME 24基准上表现出色,8k和16k上下文分别得0.121和0.146,超越全注意力-R。
5. 效率分析(Efficiency Analysis)
5.1 训练速度
NSA在64k上下文上实现9.0×前向和6.0×反向加速,得益于硬件对齐设计。
5.2 解码速度
表4. 解码内存访问量
上下文长度 | 全注意力 | NSA | 预期加速 |
---|---|---|---|
8192 | 8192 | 2048 | 4× |
65536 | 65536 | 5632 | 11.6× |
NSA显著降低内存访问量,加速随序列长度增加而提升。
6. 结论(Conclusion)
本研究提出了NSA,一种高效的长上下文建模架构。NSA通过分层稀疏设计和硬件优化,在保持性能的同时显著降低计算延迟,推动了稀疏注意力研究的前沿。