过去几年中,由于长文本分析、音视频生成等应用对Transformer的长序列处理性能的要求日渐提高。如何优化Transformer使之能够处理更长的序列成为了大模型工程化的重要挑战。然而由于注意力层的运行时间和内存会随序列长度的增加呈二次(平方)增加,使得注意力层的开销增加是主要来源,并导致了最终Transformer的性能瓶颈瓶颈。FlashAttention旨在加速注意力计算并减少内存占用。FlashAttention利用底层硬件的内存层次知识,例如GPU的内存层次结构,来提高计算速度和减少内存访问开销。 FlashAttention的核心原理是通过将输入分块并在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写操作。FlashAttention系列迭代发展至今已经产生到第三个版本了,在前文中我们已经记录了FlashAttention的相关内容,感兴趣的话可以自行移步阅读:
FlashAttention1:
《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》
FlashAttention2:
《FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning》
本文是自己闲暇时间里对原论文FlashAttention3的阅读记录,感兴趣的话可以参考一下,如果想要深入了解研究内容详情的话,建议移步阅读原英文论文,地址在这里,如下所示:
摘要
作为无处不在的Transformer架构的核心层,注意力机制是大语言模型和长上下文应用的瓶颈。FlashAttention通过最小化内存读/写来加速GPU上的注意力计算。然而,它尚未利用最近硬件中的新功能,FlashAttention-2在H100 GPU上仅实现了35%的利用率。我们开发了三种主要技术来加速Hopper GPU上的注意力计算:利用Tensor Cores和TMA的异步性,通过warp专业化(1)重叠整体计算和数据移动,(2)交错块级矩阵乘法和softmax操作,以及(3)块量化和不连贯处理,利用硬件支持的FP8低精度。我们证明,我们的方法FlashAttention-3在H100 GPU上实现了1.5-2.0倍的加速,FP16达到了高达740 TFLOPs/s(75%的利用率),FP8接近1.2 PFLOPs/s。我们验证了FP8 FlashAttention-3比基线FP8注意力实现了2.6倍的更低数值误差。
1 引言
对于Transformer架构[59],注意力机制构成了主要的计算瓶颈,因为计算查询和键的自注意力分数在序列长度上具有二次缩放。将注意力扩展到更长的上下文将解锁新的能力(对多个长文档[24, 43, 50]和大型代码库中的文件[30, 48]进行建模和推理),新的模态(高分辨率图像[11]、音频[23]、视频[25]),以及新的应用(用户与长历史交互[53],具有长视野的代理工作流[62])。这引发了在长上下文情况下加速注意力的浓厚兴趣,包括通过近似[14, 27, 56]和软件优化([17, 29, 45]),甚至是替代架构[22, 42, 55]。
在这项工作中,我们基于Dao等人[17]的工作,开发了将GPU执行模型和硬件特性知识集成到其高层设计中的精确注意力算法。在[17]中,Dao等人介绍了FlashAttention,这是一种用于并行化注意力的创新分块策略,通过将所有注意力操作融合到一个GPU内核中,消除了对慢速全局内存的中间读/写。Dao[15]将算法重构为FlashAttention-2,也在序列长度维度上并行化,并对键和值矩阵的块执行前向传递的内循环,从而提高了GPU的占用率和工作的分布。然而,我们观察到FlashAttention-2在较新的GPU上相对于优化的矩阵乘法(GEMM)内核仍然实现了较低的利用率,例如在Hopper H100 GPU上为35% vs. 80-90%。部分原因可能是实现层面的差异,例如在针对Tensor Cores时没有使用Hopper特定的指令代替Ampere指令。一些工作,如ThunkerKitten[52]和cuDNN 9[39]表明,通过使用Hopper特定的指令和基于块的抽象,可以加速注意力计算并简化实现。
更根本的是,FlashAttention-2的算法遵循简化的同步模型,在其设计中没有明确使用异步和低精度。异步性是由于硬件专门化加速了ML工作负载中最重要操作的结果:特定的硬件单元执行矩阵乘法(Tensor Cores)或内存加载(Tensor Memory Accelerator - TMA),与执行逻辑、整数和浮点计算的其余CUDA核心分开。低精度,如Hopper中的FP8和Blackwell中的FP4,延续了FP16(2017年Pascal)和BF16(2020年Ampere)的趋势,是一种经过验证的技术,可以在相同的功耗和芯片面积下获得双倍或四倍的吞吐量。我们在第2.2节中回顾了Hopper在这些方向上提供的功能。技术挑战是重新设计FlashAttention-2以利用这些硬件特性:异步性要求即使在矩阵乘法和softmax之间存在依赖关系的情况下也能重叠计算,而低精度则需要特别注意最小化量化误差,特别是在LLM中存在异常值特征的情况下[20, 54]。
为此,我们提出了FlashAttention-3,它贡献并综合了三种新思想,以进一步提高较新GPU架构上的性能:1
Footnote 1: 我们在NVIDIA的Hopper架构的背景下描述了我们的结果。然而,我们的算法适用于任何具有足够强大的异步执行和低精度能力的GPU架构。
-
生产者-消费者异步性:我们定义了一种warp专业化的软件流水线方案,通过将数据的生产者和消费者分成不同的warp,从而扩展了算法隐藏内存和指令发布延迟的能力。
-
在异步块级GEMM下隐藏softmax:我们重叠了softmax中涉及的相对低吞吐量的非GEMM操作(如浮点乘加和指数)与GEMM的异步WGMMA指令。作为其中的一部分,我们重构了FlashAttention-2算法,以规避softmax和GEMM之间的一些顺序依赖关系。例如,在我们的2阶段算法版本中,当softmax对分数矩阵的一个块执行时,WGMMA在异步代理中计算下一个块。
-
硬件加速的低精度GEMM:我们调整前向传递算法,以允许针对FP8 Tensor Cores进行GEMM,几乎使测量的TFLOPs/s翻倍。这需要在内存中如何布局FP32累加器和FP8操作数矩阵的块方面,满足WGMMA的布局一致性要求。我们使用块量化和不连贯处理技术来减轻由于转向FP8精度而导致的精度损失。
为了验证我们的方法,我们在H100 SXM5 GPU上对一系列参数进行了FlashAttention-3的基准测试,并表明(1)FP16在前向传递中比FlashAttention-2快1.5-2.0倍(达到高达740 TFLOPs/s),在后向传递中快1.5-1.75倍,(2)FP8达到了接近1.2 PFLOPs/s,(3)对于长序列长度,FP16优于NVIDIA cuDNN库中的最先进注意力实现,FP8具有竞争力2。我们还验证了FP16 FlashAttention-3产生的数值误差与FlashAttention-2相同,并且由于中间结果(例如softmax重新缩放)保持在FP32中,因此优于标准注意力实现。此外,FP8 FlashAttention-3结合块量化和不连贯处理在存在异常值特征的情况下比标准逐张量量化的注意力准确性高2.6倍。
Footnote 2: 更准确地说,对于头维度64,FlashAttention-3 FP8领先,而对于头维度128和256,它在没有因果掩码的情况下持平,在有因果掩码的情况下落后。
我们以宽松的许可证开源FlashAttention-33,并计划将其集成到PyTorch和Hugging Face库中,以使最多的研究人员和开发者受益。
Footnote 3: FlashAttention-3可在https://github.com/Dao-AILab/flash-attention获取
2 背景:多头注意力和GPU特性
多头注意力
GPU硬件特性和执行模型
我们描述了与FlashAttention-3相关的GPU执行模型方面,重点是NVIDIA Hopper架构作为此模型的具体实例。
表1:NVIDIA Hopper H100 SXM5 GPU的线程-内存层次结构。
内存层次结构:GPU的内存被组织为数据位置的层次结构,容量与带宽成反比(表1)4。全局内存(GMEM),也称为HBM,是所有流多处理器(SMs)可访问的片外DRAM。数据从GMEM透明地缓存到片上L2缓存。接下来,每个SM包含一个小的片上、程序员管理的、高度分区的缓存,称为共享内存(SMEM)。最后,每个SM内有寄存器文件。
Footnote 4: Luo等人[34]报告了每个SM每个时钟周期128字节的共享内存带宽,我们将其乘以132个SM和1830 MHz的提升时钟。
线程层次结构:GPU的编程模型围绕称为线程的执行单元的逻辑分组进行组织。从最细到最粗的层次,线程层次结构由线程、warp(32个线程)、warp组(4个连续的warp)、线程块(即协作线程数组或CTA)、线程块集群(在Hopper中)和网格组成。
这两个层次紧密相连。同一CTA中的线程在同一SM上协同调度,同一集群中的CTA在同一GPC上协同调度。SMEM可由CTA内的所有线程直接寻址,而每个线程最多有256个寄存器(RMEM)专用于自身。
异步性和warp专业化:GPU是吞吐量处理器,依赖并发性和异步性来隐藏内存和执行延迟。对于GMEM和SMEM之间的异步内存复制,Hopper有Tensor Memory Accelerator(TMA)作为专用硬件单元[38, §7.29]。此外,与Ampere等先前架构不同,Hopper的Tensor Core通过warp组范围的WGMMA指令[40, §9.7.14]也是异步的,可以直接从共享内存中获取其输入。
硬件对异步性的支持允许warp专业化的内核,其中CTA的warp被分成生产者或消费者角色,只发布数据移动或计算。通常,这提高了编译器生成最佳指令调度的能力[4]。此外,Hopper通过setmaxnreg[40, §9.7.17.1]支持warp组之间寄存器的动态重新分配,因此执行MMAs的warp可以比仅发布TMA的warp(只需要一个线程)获得更大的RMEM份额。
低精度数字格式:现代GPU有专门的硬件单元来加速低精度计算。例如,WGMMA指令可以针对Hopper上的FP8 Tensor Cores,在每个SM的吞吐量上比FP16或BF16提高2倍。
在注意力的背景下,这些布局限制需要对FP8算法进行某些修改,我们在第3.3节中描述了这些修改。
标准注意力和FlashAttention
根据Dao等人[17],我们将标准注意力表示为在GPU上实现注意力并将中间矩阵S和P具体化为HBM的实现。FlashAttention的主要思想是利用softmax归约的局部版本,以避免这些昂贵的中间读/写,并将注意力融合到一个内核中。局部softmax对应于算法1中消费者主循环的第18-19行以及O块的重新缩放。这一过程确实计算O的简单推导可以在[15, §2.3.1]中找到。
3 FlashAttention-3:算法
在本节中,我们描述了FlashAttention-3算法。为简单起见,我们专注于前向传递,后向传递算法在附录B.1中描述。我们首先指示如何将warp专业化与循环SMEM缓冲区集成到FlashAttention-2的基础算法中。然后我们解释如何利用WGMMA的异步性来定义重叠的GEMM-softmax 2阶段流水线。最后,我们描述了FP8所需的修改,包括布局一致性和通过块量化和不连贯处理实现的精度。
通过warp专业化和乒乓调度实现生产者-消费者异步性
乒乓调度:WGMMA和TMA的异步性,加上warp专业化,提供了重叠一个warp组的softmax计算与另一个warp组的GEMM的机会。为了激发这一点,请注意非矩阵乘法操作在现代硬件加速器上的吞吐量要低得多。例如,H100 SXM5 GPU具有989 TFLOPs的FP16矩阵乘法,但特殊函数(如softmax所需的指数)仅为3.9 TFLOPs5(必要时为softmax)。对于头维度128的FP16注意力前向传递,矩阵乘法FLOPs比指数操作多512倍,但指数的吞吐量低256倍,因此指数可以占用比矩阵乘法50%的周期。情况在FP8中更糟,其中矩阵乘法吞吐量翻倍,但指数吞吐量保持不变。
Footnote 5:CUDA编程指南指定每个流多处理器(SM)每个时钟周期可以执行16次特殊函数操作。我们将16乘以132个SM和1830 MHz时钟速度,得到3.9 TFLOPs的特殊函数。
由于指数是由单独的硬件单元(多功能单元)执行的,理想情况下,我们希望在Tensor Cores执行矩阵乘法时调度指数计算。为此,我们使用同步屏障(bar.sync指令)强制warp组1的GEMM(一个迭代中的PV和下一个迭代中的QKT)在warp组2的GEMM之前调度。结果,warp组1的softmax将在warp组2执行其GEMM时调度。然后角色交换,warp组2执行softmax,而warp组1执行GEMM(因此,“乒乓”调度)。这在图1中进行了说明。尽管在实践中乒乓调度并不像图中描述的那样干净,但我们通常发现这可以提高性能(例如,对于头维度128和序列长度8192的FP16前向传递,从570 TFLOPs提高到620-640 TFLOPs)。
图1:2个warp组的乒乓调度以重叠softmax和GEMM:一个warp组的softmax应在另一个warp组的GEMM运行时调度。相同的颜色表示相同的迭代。
注意力变体:对于多头查询注意力[51]和分组查询注意力[3],我们遵循FlashAttention-2中的方法,调整张量索引以避免在HBM中复制K和V。
在warp组内重叠GEMM和softmax
即使在单个warp组内,我们也可以重叠softmax中的一些指令与GEMM中的一些指令。我们描述了一种实现这一点的方法。
图2:2阶段WGMMA-softmax流水线
在注意力算法中,内循环(主循环)内的操作具有顺序依赖关系,阻碍了单次迭代内的并行化。例如,(局部)softmax(第18至19行)依赖于第一个GEMM的输出,而第二个GEMM将其结果
作为操作数。确实,算法3.1中第17和21行的等待语句序列化了softmax和GEMM的执行。然而,我们可以通过在迭代之间流水线化来打破这些依赖关系,通过寄存器中的额外缓冲区。追求这一想法,我们提出了以下两阶段6 GEMM-softmax流水线算法:
Footnote 6:我们实验了3阶段流水线算法,以进一步重叠第二个WGMMA与softmax。尽管这种方法提供了更高的Tensor Core利用潜力,但由于额外的流水线阶段,需要更多的寄存器,使得块大小和流水线深度之间的权衡更加困难。3阶段算法的详细描述和评估结果见附录B.3。
图1:2个warp组的乒乓调度以重叠softmax和GEMM:一个warp组的softmax应在另一个warp组的GEMM运行时调度。相同的颜色表示相同的迭代。
图2:2阶段WGMMA-softmax流水线
算法2 FlashAttention-3消费者warp组前向传递
算法2作为FP16精度的完整FlashAttention-3算法中消费者路径的替代。在高层次上,我们使用WGMMA作为异步GEMM的代名词。在主循环(第8至16行)内,迭代j的第二个WGMMA操作(第11行)与迭代j+1的softmax操作(第13行)重叠。
虽然上述流水线结构提供了理论性能增益,但有几个实际方面需要考虑:
编译器重新排序:伪代码表示了一个理想化的执行顺序,但编译器(NVCC)经常重新排列指令以进行优化。这可能会破坏精心设计的WGMMA和非WGMMA操作流水线序列,可能导致意外行为或性能增益减弱。对SASS代码的分析表明,编译器按预期生成了重叠代码(附录B.2)。
寄存器压力:为了保持最佳性能,应最小化寄存器溢出。然而,2阶段流水线需要额外的寄存器来存储中间结果并在阶段之间保持上下文。具体来说,必须在寄存器中保留额外的Snext,导致每个线程块的额外寄存器使用量为Br×Bc×sizeof(float)。这种增加的寄存器需求可能与使用更大的块大小(另一种常见优化)冲突,这也是寄存器密集型的。在实践中,应根据分析结果进行权衡。
3阶段流水线:扩展上述描述的2阶段算法,我们提出了一个3阶段变体,进一步重叠第二个WGMMA与softmax。尽管这种方法提供了更高的Tensor Core利用潜力,但由于额外的流水线阶段,需要更多的寄存器,使得块大小和流水线深度之间的权衡更加困难。3阶段算法的详细描述和评估结果见附录B.3。
低精度与FP8
效率:布局转换:在FP8精度下计算FlashAttention-3的前向传递带来了额外的布局一致性挑战,而在FP16中没有遇到这些挑战。
图4:FP8操作数A寄存器WGMMA布局 – 行0和8,线程0-3,条目0-7。
首先,我们注意到输入张量Q、K和V通常在头维度上是连续的,而对于第二个GEMM的FP8 WGMMA,我们需要V(或者更确切地说,加载到SMEM中的V块)在序列长度维度上是连续的。由于TMA加载本身不能改变连续维度,我们需要(1)在GMEM中作为预处理步骤对V进行转置,或(2)在加载到SMEM后在内核中对V块进行转置。为了实现选项(1),我们可以(1a)将转置融合到前一步骤(如旋转嵌入)的结尾,或(1b)调用一个独立的预处理转置内核7以交换序列长度和头维度的步幅。然而,(1a)难以集成到标准库中,而(1b)在内存受限的情况下(如推理)过于浪费。
Footnote 7:优化的转置内核将达到接近设备带宽的速度[46]。
相反,对于FP8 FlashAttention-3,我们选择选项(2)。对于内核转置,我们利用LDSM(ldmatrix)和STSM(stmatrix)指令,这些指令涉及一个warp的线程集体将SMEM加载到RMEM并将RMEM存储到SMEM,粒度为128字节。8 LDSM/STSM指令在寄存器高效的情况下允许我们在生产者warp组中执行它们,并且在进行内存复制时能够转置布局。此外,在第一次迭代后,我们可以安排下一个V块的转置在涉及前一个V块和当前K块的两个WGMMA的阴影中执行。
Footnote 8:在PTX文档中,LDSM/STSM被描述为复制具有16位条目的8×8矩阵[40, §9.7.13.4.15-16],但我们可以在FP8精度的情况下将8位条目打包两次以使用LDSM/STSM。然而,LDSM/STSM的转置版本不能拆分打包的8位条目,这需要在LDSM和STSM之间进行某些寄存器移动以实际执行块级转置;我们省略了细节。
其次,我们观察到与FP16不同,FP8 WGMMA的FP32累加器的内存布局与假设在寄存器中持有的操作数A的布局不同。我们在图3和图4中描述了这两个布局的片段,其中条目按列出的顺序保存在每个线程的寄存器中。通过使用字节置换指令,我们可以将第一个WGMMA的累加器转换为适合第二个WGMMA的格式,并与内核转置生成的V块布局兼容。具体来说,参考图3,我们将顺序更改为:
{d0 d1 d4 d5 d2 d3 d6 d7},
并且这种寄存器置换在每8个字节上重复。在逻辑形状的P块方面,这种操作置换其列(例如,列0189现在成为前四列)。对于WGMMA然后计算正确的输出块,我们可以相应地安排内核转置写出V块的匹配行置换。9
Footnote 9:内核转置提供的这种额外自由消除了使用洗牌指令在不同线程之间改变寄存器所有权的需求,我们之前在[7]中描述了这一点。
我们验证了这些技术在第4.3节中将数值误差降低了高达2.6倍。
4 实验验证
我们使用CUTLASS[57]中的原语(如WGMMA和TMA抽象)来实现FlashAttention-3并评估其效率和准确性。
-
*注意力基准测试:我们测量了FlashAttention-3在不同序列长度下的运行时间,并将其与PyTorch中的标准实现、FlashAttention-2、Triton中的FlashAttention-2(使用H100特定指令)以及cuDNN中针对H100 GPU优化的FlashAttention-2实现进行了比较。我们确认FlashAttention-3比FlashAttention-2快1.5-2.0倍,比Triton中的FlashAttention-2快1.5倍。FlashAttention-3在H100 GPU上达到了高达740 TFLOPs/s,即理论最大TFLOPs/s的75%。
-
*消融研究:我们确认了我们的算法改进(warp专业化和GEMM-softmax流水线)对FlashAttention-3加速的贡献。
-
*FP8注意力的准确性:我们验证了块量化和不连贯处理将FP8 FlashAttention-3的数值误差降低了2.6倍。
注意力基准测试
我们在H100 80GB SXM5 GPU上测量了不同设置(有无因果掩码,头维度64或128)下FP16输入的不同注意力方法的运行时间。我们在图5和图6中报告了结果,显示FlashAttention-3在前向传递中比FlashAttention-2快1.5-2.0倍,在后向传递中快1.5-1.75倍。与标准注意力实现相比,FlashAttention-3可以快达3-16倍。对于中等和长序列(1k及以上),FlashAttention-3甚至超过了针对H100 GPU优化的供应商库(cuDNN - 闭源)的速度。
基准设置:我们将序列长度设置为512、1k、...、16k,并设置批量大小,使得总token数为16k。我们将隐藏维度设置为2048,头维度设置为64、128或256(即32个头、16个头或8个头)。为了计算前向传递的FLOPs,我们使用:
对于因果掩码,我们将这个数字除以2,以考虑大约只有一半的条目被计算。为了得到后向传递的FLOPs,我们将前向传递的FLOPs乘以2.5(因为前向传递中有2个矩阵乘法,后向传递中有5个矩阵乘法,由于重计算)。
我们还在类似设置下测量了FP8前向传递的运行时间。我们在图7中报告了头维度256的结果,并在附录C.2中给出了完整结果。
消融研究:2阶段流水线实验
我们消融了非因果FP16 FlashAttention-3的2阶段WGMMA-softmax流水线和warp专业化,固定参数{batch, seqlen,nheads,hdim} = {4,8448,16,128}。表2中的结果确认了我们的算法改进(异步性与warp专业化和GEMM与softmax的重叠)带来了显著的加速,从570提高到661 TFLOPs。
图5:H100 GPU上的注意力前向速度(FP16/BF16)
数值误差验证
由于对FlashAttention的数值误差[21]感兴趣,我们将FlashAttention-2、FlashAttention-3和标准注意力实现与FP64中的参考实现进行了比较。为了模拟LLM中的异常值特征和激活[54, 20],我们使用以下分布生成Q、K、V的条目:
也就是说,每个条目是均值为零、标准差为1的正态分布,但对于0.1%的条目,我们添加了一个独立项,其标准差为100。然后我们测量表3中的均方根误差(RMSE)。在FP16中,FlashAttention-2和FlashAttention-3都比标准实现实现了1.7倍的更低RMSE,因为中间结果(softmax)保持在FP32中。基线FP8注意力使用逐张量缩放,矩阵乘法累加器为FP32,中间softmax结果保持在FP16中。得益于块量化和不连贯处理,FP8 FlashAttention-3比这个基线准确性高2.6倍。
表2:流水线消融测量
表3:FP16和FP8(e4m3)中的数值误差比较
图6:H100 GPU上的注意力后向速度(FP16/BF16)
图7:H100 GPU上的注意力前向速度(FP8)
5 讨论、局限性、结论
通过FlashAttention-3,我们证明了新的编程技术和硬件特性(如异步性和低精度)可以显著影响注意力的效率和准确性。我们能够将注意力加速1.5-2.0倍,比FlashAttention-2快,并将FP8数值误差降低2.6倍,比标准逐张量量化更准确。我们希望在未来解决的一些工作局限性包括:优化LLM推理,将持久内核设计集成到FP8内核中,10以及理解低精度注意力在大规模训练中的影响。尽管我们在本工作中专注于Hopper GPU,但我们预计这里开发的技术将适用于其他硬件加速器。我们希望更快速和更准确的注意力原语将解锁长上下文任务中的新应用。
Footnote 10:在我们的基准测试中,FP16 FlashAttention-3有一个持久内核和负载均衡策略,而FP8 FlashAttention-3没有。这解释了为什么FP8 FlashAttention-3在短序列长度和因果掩码情况下不如FP8 cuDNN内核表现好。