FlashAttention新升级!斯坦福博士一人重写算法,第二代实现了最高9倍速提升。Transformer上下文长度史诗级提升

继超快且省内存的注意力算法FlashAttention爆火后,升级版的2代来了。

FlashAttention-2是一种从头编写的算法,可以加快注意力并减少其内存占用,且没有任何近似值。

比起第一代,FlashAttention-2速度提升了2倍。

甚至,相较于PyTorch的标准注意力,其运行速度最高可达9倍。

一年前,StanfordAILab博士Tri Dao发布了FlashAttention,让注意力快了2到4倍,如今,FlashAttention已经被许多企业和研究室采用,广泛应用于大多数LLM库。

如今,随着长文档查询、编写故事等新用例的需要,大语言模型的上下文以前比过去变长了许多——GPT-4的上下文长度是32k,MosaicML的MPT上下文长度是65k,Anthropic的Claude上下文长度是100k。

但是,扩大Transformer的上下文长度是一项极大的挑战,因为作为其核心的注意力层的运行时间和内存要求,是输入序列长度的二次方。

Tri Dao一直在研究FlashAttention-2,它比v1快2倍,比标准的注意力快5到9倍,在A100上已经达到了225 TFLOP/s的训练速度!

论文地址:https://tridao.me/publications/flash2/flash2.pdf

项目地址:https://github.com/Dao-AILab/flash-attention

FlashAttention-2:更好的算法、并行性和工作分区

端到端训练GPT模型,速度高达225 TFLOP/s

虽说FlashAttention在发布时就已经比优化的基线快了2-4倍,但还是有相当大的进步空间。

比方说,FlashAttention仍然不如优化矩阵乘法(GEMM)运算快,仅能达到理论最大FLOPs/s的25-40%(例如,在A100 GPU上的速度可达124 TFLOPs/s)。

FlashAttention_基准测试

GEMM如何用于卷积

在过去的几个月里,研究人员一直在开发FlashAttention-2,它的性能指标比第一代更强。

研究人员表示,2代相当于完全从头重写,使用英伟达的CUTLASS 3.x及其核心库CuTe。从速度上看,FlashAttention-2比之前的版本快了2倍,在A100 GPU上的速度可达230 TFLOPs/s。

当使用端到端来训练GPT之类的语言模型时,研究人员的训练速度高达225 TFLOPs/s(模型的FLOP利用率为72%)。

对注意力计算重新排序

我们知道,FlashAttention是一种对注意力计算进行重新排序的算法,利用平铺、重新计算来显著加快计算速度,并将序列长度的内存使用量从二次减少到线性。

FlashAttention_人工智能_02

研究人员将输入块从HBM(GPU内存)加载到SRAM(快速缓存),并对该模块执行注意,更新HBM中的输出。

由于没有将大型中间注意力矩阵写入HBM,内存的读/写量也跟着减少,进而带来了2-4倍的执行时间加速。

下图是FlashAttention的前向传递图:通过平铺和softmax重新缩放,研究人员人员按模块进行操作,避免从HBM读取或是写入,同时获得正确输出,无需近似。

FlashAttention_共享内存_03

然而,FlashAttention仍然存在一些低效率的问题,这是由于不同线程块之间的工作划分并不理想,以及GPU上的warp——导致低占用率或不必要的共享内存读写。

更少的non-matmul FLOP(非矩阵乘法浮点计算数)

研究人员通过调整FlashAttention的算法来减少non-matmul FLOP的次数。这非常重要,因为现代GPU有专门的计算单元(比如英伟达GPU上的张量核心),这就使得matmul的速度更快。

例如,A100 GPU FP16/BF16 matmul的最大理论吞吐量为312 TFLOPs/s,但non-matmul FP32的理论吞吐量仅为 19.5 TFLOPs/s。

另外,每个非matmul FLOP比matmul FLOP要贵16倍。

所以为了保持高吞吐量,研究人员希望在matmul FLOP上花尽可能多的时间。

研究人员还重新编写了FlashAttention中使用的在线softmax技巧,以减少重新缩放操作的数量,以及边界检查和因果掩码操作,而无需更改输出。

更好的并行性

FlashAttention v1在批大小和部数量上进行并行化处理。研究人员使用1个线程块来处理一个注意力头,共有 (batch_size * head number) 个线程块。

FlashAttention_基准测试_04

在前向处理(左图)中,研究者将Worker(线程块)并行化,每个Worker负责处理注意力矩阵的一个行块。在后向处理过程中(右图),每个Worker处理注意力矩阵的一个列块

每个线程块都在流式多处理器 (SM)运行,例如,A100 GPU上有108个这样的处理器。当这个数字很大(比如 ≥80)时,这种调度是有效的,因为在这种情况下,可以有效地使用GPU上几乎所有的计算资源。

在长序列的情况下(通常意味着更小批或更少的头),为了更好地利用GPU上的多处理器,研究人员在序列长度的维度上另外进行了并行化,使得该机制获得了显著加速。

更好的工作分区

即使在每个线程块内,研究人员也必须决定如何在不同的warp(线程束)之间划分工作(一组32个线程一起工作)。研究人员通常在每个线程块使用4或8个warp,分区方案如下图所示。

研究人员在FlashAttention-2中改进了这种分区,减少了不同warp之间的同步和通信量,从而减少共享内存读/写。

FlashAttention_人工智能_05

对于每个块,FlashAttention将K和V分割到4个warp上,同时保持Q可被所有warp访问。这称为「sliced-K」方案。

然而,这样做的效率并不高,因为所有warp都需要将其中间结果写入共享内存,进行同步,然后再将中间结果相加。

而这些共享内存读/写会减慢FlashAttention中的前向传播速度。

在FlashAttention-2中,研究人员将Q拆分为4个warp,同时保持所有warp都可以访问K和V。

在每个warp执行矩阵乘法得到Q K^T的一个切片后,它们只需与共享的V切片相乘,即可得到相应的输出切片。

这样一来,warp之间就不再需要通信。共享内存读写的减少就可以提高速度。

新功能:头的维度高达256,多查询注意力

FlashAttention仅支持最大128的头的维度,虽说适用于大多数模型,但还是有一些模型被排除在外。

FlashAttention-2现在支持256的头的维度,这意味着GPT-J、CodeGen、CodeGen2以及Stable Diffusion 1.x等模型都可以使用FlashAttention-2来获得加速和节省内存。

v2还支持多查询注意力(MQA)以及分组查询注意力(GQA)。

FlashAttention_分块_06

GQA为每组查询头共享单个key和value的头,在多头和多查询注意之间进行插值

这些都是注意力的变体,其中多个查询头会指向key和value的同一个头,以减少推理过程中KV缓存的大小,并可以显著提高推理的吞吐量。

注意力基准

研究人员人员在A100 80GB SXM4 GPU 上测量不同设置(有无因果掩码、头的维度是64或128)下不同注意力方法的运行时间。

FlashAttention_基准测试_07

研究人员发现FlashAttention-2比第一代快大约2倍(包括在xformers库和Triton中的其他实现)。与PyTorch中的标准注意力实现相比,FlashAttention-2的速度最高可达其9倍。

FlashAttention_分块_08

A100 GPU上的前向+后向速度只需在H100 GPU上运行相同的实现(不需要使用特殊指令来利用TMA和第四代Tensor Core等新硬件功能),研究人员就可以获得高达335 TFLOPs/s的速度。

FlashAttention_人工智能_09

H100 GPU上的前向+后向速度

当用于端到端训练GPT类模型时,FlashAttention-2能在A100 GPU上实现高达225TFLOPs/s的速度(模型FLOPs利用率为72%)。

与已经非常优化的FlashAttention模型相比,端到端的加速进一步提高了1.3倍。

FlashAttention_人工智能_10

未来的工作

速度上快2倍,意味着研究人员可以用与之前训练8k上下文模型相同的成本,来训练16k上下文长度的模型。这些模型可以理解长篇书籍和报告、高分辨率图像、音频和视频。

同时,FlashAttention-2还将加速现有模型的训练、微调和推理。

在不久的将来,研究人员还计划扩大合作,使FlashAttention广泛适用于不同类型的设备(例如H100 GPU、AMD GPU)以及新的数据类型(例如fp8)。

下一步,研究人员计划针对H100 GPU进一步优化FlashAttention-2,以使用新的硬件功能(TMA、第四代Tensor Core、fp8等等)。

将FlashAttention-2中的低级优化与高级算法更改(例如局部、扩张、块稀疏注意力)相结合,可以让研究人员用更长的上下文来训练AI模型。

研究人员也很高兴与编译器研究人员合作,使这些优化技术更好地应用于编程。

参考资料:

 https://princeton-nlp.github.io/flash-atttention-2/

别再「浪费」GPU了,FlashAttention重磅升级,实现长文本推理速度8倍提升

处理小说、法律文件等长文本是大模型的一个重要应用方向,但也面临速度上的挑战。FlashAttention 作者 Tri Dao 等人提出的「Flash-Decoding」通过充分利用 GPU,可以将大模型的长上下文推理速度提高至 8 倍。

最近,像 ChatGPT 或 Llama 这样的大型语言模型(LLM)引起了前所未有的关注。然而,它们的运行成本仍然极高。虽然生成单个响应可能仅需 0.01 美元(在 AWS 上的 8xA100 实例上运行几秒钟),但当扩大规模以满足数十亿用户的需求时,成本会迅速累积。而且,这些用户可能每天与 LLM 进行多次互动。某些用例的成本更高,例如代码自动生成,因为它会随着每次输入新字符而运行。随着 LLM 应用的不断增加,即使在生成时间方面实现细微的效率提升,也将产生巨大的影响。

LLM 推理(或「解码」)是一个迭代的过程:token 逐个生成。生成包含 N 个 token 的完整句子需要通过模型进行 N 次前向传递。幸运的是,我们可以缓存先前计算的 token:这意味着单个生成步骤不依赖于上下文长度,除了一个单独的操作 —— 注意力。这个操作导致上下文长度不能很好地扩展。

在 LLM 的重要新兴用例中,有一些需要利用更长的上下文。只有拥有了更长的上下文窗口,LLM 才能对更长的文档进行推理,无论是总结文档还是回答其中的问题。此外,它们还可以保持更长的对话历史,甚至在编写代码之前处理整个代码库。举个例子,在 2022 年,大多数 LLM 的上下文长度最多为 2k(例如 GPT-3),但现在,有些开源 LLM 已经可以扩展到 32k(比如 Llama-2-32k),甚至有些模型已经达到了 100k(比如 CodeLlama)。在这些情境中,注意力操作在推理过程中占据了相当大的时间比例。

在扩展 batch size 维度时,即使上下文相对较短,注意力也可能成为一个瓶颈。这是因为随着 batch 维度的增加,需要读取的内存量也会增加,而对于模型的其余部分,内存需求只取决于模型的大小。

为了解决上述问题,FlashAttention 的作者 Tri Dao 等人提出了一项名为「Flash-Decoding」的技术,它显著加速了推理过程中的注意力计算,使长序列的处理生成速度提高到了原来的 8 倍。其主要思想是以最快的速度并行加载键和值,然后分别重新缩放和合并结果,以维持正确的注意力输出。

解码时的多头注意力

在解码期间,生成的每个新 token 都需要关注所有先前的 token,以计算:softmax (queries @ keys.transpose) @ values

这个操作已经在训练阶段通过 FlashAttention 进行了优化(包括最近的 v1 和 v2 版本),瓶颈是读写中间结果的内存带宽(如 Q @ K^T)。然而,这些优化并不直接适用于推理情况,因为瓶颈不同。在训练中,FlashAttention 并行处理 batch size 和查询长度两个维度。而在推理过程中,查询长度通常为 1:这意味着,如果 batch size 小于 GPU 上的流多处理器(streaming multiprocessor,SM)数量(例如 A100 有 108 个),该操作只会利用 GPU 的一小部分!特别是在处理长上下文时,情况尤为明显,因为它需要较小的 batch size 以适应 GPU 内存。当 batch size 为 1 时,FlashAttention 将使用不到 1% 的 GPU!

FlashAttention_人工智能_11

FlashAttention 只在查询块和 batch size 之间并行,并且在解码期间不会设法占用整个 GPU。

使用矩阵乘法基元也能执行注意力计算,这样就不需要使用 FlashAttention 了。在这种情况下,该操作会占用整个 GPU,但会启动许多写入和读取中间结果的内核,因此并不是最优的做法。

更快的注意力解码:Flash-Decoding

新方法 Flash-Decoding 基于 FlashAttention,同时引入了一个新的并行维度:键值序列的长度。它综合了上述两种方法的优点。与 FlashAttention 类似,它在全局内存中存储的额外数据很少。然而,只要上下文足够长,即使 batch size 较小,它也能充分利用 GPU。

FlashAttention_基准测试_12

Flash-Decoding 也在键和值之间并行化,代价是一个小的最终归约(reduction 步骤。

Flash-Decoding 主要有三个工作步骤:

  1. 首先,将键 / 值分成更小的块;
  2. 使用 FlashAttention 并行计算查询与每个这些分块的注意力,为每行和每个分块额外写入一个标量值:注意力值的 log-sum-exp
  3. 最后,通过对所有分块进行归约来计算实际输出,使用 log-sum-exp 来调整每个分块的贡献。

这一切之所以可行,都是因为注意力 /softmax 可以进行迭代计算。在 Flash-Decoding 中,它在两个级别上被使用:在分块内部(类似 FlashAttention),以及跨分块进行最终的归约计算。

实际操作中,步骤(1)不涉及任何 GPU 操作,因为键 / 值块是完整键 / 值张量的视图。然后,有两个独立的核函数,分别用于执行步骤(2)和(3)。

在 CodeLlama 34B 上进行的基准测试

为了验证上述新方法,研究者对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,一般来说,结果应该适用于许多大型语言模型。研究者在不同序列长度下(从 512 到 64k),以 tok/s 为单位来测量解码速度,并比较了多种计算注意力的方式:

  • Pytorch:使用纯粹的 PyTorch 基元来运行注意力计算(不使用 FlashAttention);
  • FlashAttention v2;
  • FasterTransformer:使用 FasterTransformer 的注意力内核;
  • Flash-Decoding;
  • 以及一个上限值,该值计算了从内存中读取整个模型和 KV-cache 所需的时间

对于非常大的序列,Flash-Decoding 可以将解码速度提高至 8 倍,并且比其他方法的扩展性要好得多。

FlashAttention_基准测试_13

在 prompt 比较小时,所有方法表现接近。但是当序列长度从 512 增加到 64k 时,除了 Flash-Decoding,其他方法的可扩展性都很差。在 Flash-Decoding 的这种模式下(batch size 为 1),扩展序列长度对生成速度的影响很小。

组件级微基准测试

研究者还在 A100 上对多头注意力进行了微基准测试,输入为 f16,考虑了不同的序列长度和 batch size。他们将 batch size 设置为 1,并且使用 16 个 128 维的查询头,以及 2 个键 / 值头(分组查询注意力),这与在 4 个 GPU 上运行的 CodeLLaMa-34b 使用的维度相匹配。

FlashAttention_共享内存_14

上述微基准测试展示了多头注意力的运行时间,单位为微秒。Flash-Decoding 在序列长度扩展到高达 64k 时,几乎实现了恒定的运行时间。

之前测量的高达 8 倍的端到端加速是可能的,因为注意力本身的速度比 FlashAttention 快高达 50 倍。在序列长度达到 32k 之前,注意力的时间大致是恒定的,因为 Flash-Decoding 能够完全利用 GPU。

使用 Flash-Decoding

Flash-decoding 可以在以下链接中找到:

  • FlashAttention 包,从 v2.2 开始:https://github.com/Dao-AILab/flash-attention/tree/main
  • xFormers 包(搜索 xformers.ops.memory_efficient_attention),从 0.0.22 开始:调度程序将根据问题的大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个高效的 triton 内核,该内核实现了 Flash-Decoding 算法。

一个完整的使用 LLaMa v2 / CodeLLaMa 的解码示例可以在 FlashAttention  repo 和 xFormers  repo 中找到。此外,作者还提供了一个简单的 LLaMa v1/v2 模型的高效解码代码示例,旨在快速、易读、有教育意义和易于修改。

参考链接:https://princeton-nlp.github.io/flash-decoding/