论文精读——FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

系列文章目录



Abstract

  Transformer 模型在处理长序列时速度慢且内存消耗大,因为自注意力机制的时间和内存复杂度与序列长度呈二次关系。近似注意力方法试图通过牺牲模型质量来降低计算复杂度,但通常无法实现实际的加速。我们认为,一个缺失的原则是使注意力算法具备 IO 意识——即考虑 GPU 内存不同层级之间的读写操作。我们提出了 FlashAttention,这是一种具备 IO 意识的精确注意力算法,通过分块技术Tiling减少了 GPU 高带宽内存(HBM)和 GPU 芯片上 SRAM 之间的内存读写次数。我们分析了 FlashAttention 的 IO 复杂度,发现它比标准注意力机制需要更少的 HBM 访问次数,并且在一系列 SRAM 大小下是最优的。我们还将 FlashAttention 扩展到块稀疏注意力,得到了一种比任何现有近似注意力方法都更快的近似注意力算法。FlashAttention 比现有基准更快地训练 Transformer 模型:在 BERT-large(序列长度 512)上,与 MLPerf 1.1 训练速度记录相比,端到端实际时间加速了 15%,在 GPT-2(序列长度 1K)上加速了 3 倍,在长距离竞技场(序列长度 1K-4K)上加速了 2.4 倍。FlashAttention 和块稀疏 FlashAttention 使 Transformer 能够处理更长的上下文,从而获得更高质量的模型(GPT-2 的困惑度提高了 0.7,长文档分类提高了 6.4 个百分点),并带来了全新的能力:首次在 Path-X 挑战(序列长度 16K,准确率 61.4%)和 Path-256(序列长度 64K,准确率 63.1%)上实现超越随机性能的 Transformer。


提示:以下是本篇文章正文内容,下面案例可供参考

一 INTROCUTION

  Transformer 模型 [82] 已经成为自然语言处理和图像分类等应用中最广泛使用的架构。Transformer 模型变得越来越大 [5] 和越来越深 [83],但为它们配备更长的上下文仍然很困难 [80],因为其核心的自注意力模块的时间和内存复杂度与序列长度呈二次关系。一个重要的问题是,使注意力更快、更内存高效是否可以帮助 Transformer 模型解决长序列的运行时间和内存挑战。许多近似注意力方法旨在减少注意力的计算和内存需求。这些方法从稀疏近似 [51, 74] 到低秩近似 [12, 50, 84],以及它们的组合 [3, 9, 92]。尽管这些方法将计算需求降低到线性或接近线性与序列长度的关系,但其中许多方法在实际运行时间上并没有比标准注意力更快,因此没有得到广泛采用。
  许多近似注意力方法旨在降低注意力计算的计算量和内存需求。这些方法包括从稀疏近似sparse-approximation [51, 74] 到低秩近似low-rank approximation [12, 50, 84],以及它们的组合 [3, 9, 92]。尽管这些方法将计算需求降低到与序列长度呈线性或接近线性的关系,但其中许多方法与标准注意力相比,并没有实现实际运行时间的加速,也没有得到广泛应用。一个主要原因是,它们专注于减少浮点运算次数(FLOP reduction)(这可能与实际运行速度并无关联),并且往往忽略了内存访问(memory access, IO)带来的开销。

稀疏近似:稀疏近似通过对注意力矩阵进行稀疏化处理来降低复杂度。只计算部分重要位置的注意力,减少了不必要的计算。例如计算 S = Q K T S=QK^T S=QKT时,只计算稀疏模式下非零元素对应的点积运算,而非全部的 N × N N×N N×N次运算,当稀疏律s较高时,可接近线性复杂度
低秩近似:低秩近似利用低秩矩阵来近似表示注意力矩阵。假设注意力矩阵 S S S 的秩为 r r r r ≪ N r \ll N rN ),可以将 S S S 近似分解为两个低秩矩阵的乘积,如 S ≈ U V ⊤ S \approx UV^{\top} SUV ,其中 U ∈ R N × r U \in \mathbb{R}^{N×r} URN×r V ∈ R N × r V \in \mathbb{R}^{N×r} VRN×r 。在计算注意力输出时,不再直接计算 S = Q K ⊤ S = QK^{\top} S=QK 这个 N × N N×N N×N 的矩阵,而是通过对低秩矩阵的运算来近似。计算低秩矩阵的运算量主要取决于 r r r ,计算复杂度变为 O ( N r d ) O(Nrd) O(Nrd) 。由于 r ≪ N r \ll N rN ,复杂度接近线性。

  在本文中,我们认为一个被忽视的原则是使注意力算法具备 IO 感知能力 [1],即仔细考虑对不同层次的快速和慢速内存的读写操作(例如,在快速的 GPU 片上静态随机存取存储器(SRAM)和相对较慢的 GPU 高带宽内存(HBM)之间 [45],见图 1 左)。在现代 GPU 上,计算速度已经超过了内存速度 [61, 62, 63],Transformer 中的大多数操作都受到内存访问的瓶颈限制 [43]。对于类似的受内存限制的操作,IO 感知算法至关重要,因为数据的读写可能占运行时间的很大一部分,如数据库连接 [71]、图像处理 [70]、数值线性代数 [4] 等 [40, 85]。然而,像 PyTorch 和 TensorFlow 这样常见的深度学习 Python 接口并不允许对内存访问进行细粒度的控制。
在这里插入图片描述

1.GPU high bandwidth memory (HBM), HBM访存:HBM 即高带宽内存(High Bandwidth Memory),是一种为 GPU 提供高数据传输带宽的内存。特点是传输相对较快带宽高(仍低于SRAM),内存远大于SRAM。FlashAttention 算法的核心目标之一就是避免在 HBM 中读写注意力矩阵,从而减少 HBM 访问。
2.GPU kernel 内核:GPU(图形处理单元)内核是 GPU 进行并行计算的基本单元。一个 GPU 包含多个内核,这些内核可以同时处理不同的数据块,实现并行计算。FlashAttention 算法通过将所有注意力操作融合到一个 GPU 内核中,实现了更高效的计算。
3.SRAM:SRAM 即静态随机存取存储器(Static Random Access Memory),这里指 GPU 片上的 SRAM。它是位于 GPU 芯片上的高速内存,访存速度快于HBM但是容量小。FlashAttention 算法利用 SRAM 的高速特性,将输入数据分块加载到 SRAM 中进行计算,减少对 HBM 的依赖。
4.IO 感知算法:IO 通常指输入输出(Input/Output),IO 感知算法是指能够考虑到内存访问特性,对数据在不同层次内存之间的读写操作进行优化的算法。在深度学习计算中,不同层次的内存(如 SRAM 和 HBM)在速度和容量上存在差异,IO 感知算法通过合理安排数据的读取和写入顺序、时机,减少不必要的内存访问,提高整体计算效率。
5.softmax reduction 归约:softmax 将一个数值向量转换为表示概率分布的向量。softmax 归约是指对 softmax 函数的计算结果进行某种聚合或处理。在传统注意力机制计算中,softmax 用于计算注意力权重,softmax 归约涉及到对所有注意力权重进行归一化等操作,计算量较大。FlashAttention 算法为了避免在 HBM 中读写注意力矩阵,需要在不访问整个输入的情况下计算 softmax 归约,通过将输入分割成块,逐步对每个块进行 softmax 计算和结果合并,实现了这一目标。
6.tiling 平铺:在 FlashAttention 算法中,平铺是指将输入数据(如 Q、K、V 矩阵)分割成较小的块,然后对这些块进行多次遍历和计算的技术。通过这种方式,每次计算只需要将小块数据加载到高速的 SRAM 中进行处理,而不是一次性处理整个大矩阵,减少了对 HBM 的访问。在计算注意力时,将 Q、K、V 矩阵按一定大小划分为多个子矩阵块,先把 K、V 的子矩阵块加载到 SRAM,再循环加载 Q 的子矩阵块进行计算,处理完一个 Q 的子矩阵块后再处理下一个,逐步完成所有计算,这种方法就是平铺。
7.CUDA:CUDA(Compute Unified Device Architecture)是 NVIDIA 推出的一种并行计算平台和编程模型。它允许开发人员使用 C、C++ 等语言编写在 GPU 上运行的代码,充分利用 GPU 的并行计算能力加速计算任务。通过 CUDA,开发人员可以直接控制 GPU 的硬件资源,实现对内存访问的细粒度控制。FlashAttention 算法借助 CUDA 实现,利用其提供的功能优化内存访问和计算过程,提高算法的运行效率。

  我们提出了 FlashAttention,这是一种新的注意力算法,能够在显著减少内存访问的情况下计算精确的注意力。我们的主要目标是避免在 HBM 中读写注意力矩阵。这需要(i)在不访问整个输入的情况下计算 softmax reduction;(ii)在反向传播中不存储大型的中间注意力矩阵。我们应用两种成熟的技术来应对这些挑战。(i)我们重新组织注意力计算过程,将输入分割成块,并对输入块进行多次遍历,从而逐步执行 softmax 归约(也称为平铺)。(ii)我们存储前向传播中的 softmax 归一化因子,以便在反向传播中在芯片上快速重新计算注意力,这比从 HBM 读取中间注意力矩阵的标准方法更快。我们在 CUDA 中实现了 FlashAttention,以实现对内存访问的细粒度控制,并将所有注意力操作融合到一个 GPU 内核中。尽管由于重新计算导致浮点运算次数增加,但由于 HBM 访问量大幅减少,我们的算法与标准注意力相比,运行速度更快(在 GPT-2 上速度提升高达 7.6 倍 [67],见图 1 右),并且内存使用量与序列长度呈线性关系,比标准注意力更少。
  我们分析了FlashAttention的IO复杂度[1],证明它需要 O ( N 2 d 2 M − 1 ) O(N^{2} d^{2} M^{-1}) O(N2d2M1)次HBM访问,其中d是头维度,M是SRAM的大小,而标准注意力需要 Ω ( N d + N 2 ) \Omega(N d+N^{2}) Ω(Nd+N2)次HBM访问。对于典型的d和M值,FlashAttention与标准注意力相比,所需的HBM访问次数要少很多(如图2所示,最多可少9倍)。此外,我们给出了一个下限,表明对于所有的SRAM大小,没有精确的注意力算法能够在HBM访问次数上渐近地优于FlashAttention。
  我们还表明,FlashAttention 可以作为一个有用的基础,通过克服近似注意力算法在内存访问开销方面的问题,实现其潜在的性能提升。作为概念验证,我们实现了块稀疏 FlashAttention,这是一种稀疏注意力算法,甚至比 FlashAttention 还要快 2 - 4 倍,能够处理长度达 64k 的序列。我们证明块稀疏 FlashAttention 的 IO 复杂度比 FlashAttention 更好,提升倍数与稀疏率成正比。我们将在第 5 节讨论对其他操作(多 GPU 注意力计算、核回归、块稀疏矩阵乘法)的进一步扩展。我们开源了 FlashAttention,以便于基于这个基础进行开发。
  我们通过实验验证了 FlashAttention 能够通过对更长上下文进行建模来加速模型训练并提高模型质量。我们还对 FlashAttention 和块稀疏 FlashAttention 与先前的注意力实现进行了运行时和内存占用memory footprint的基准测试。

  1. 更快的模型训练:FlashAttention 在实际时间上能够更快地训练 Transformer 模型。我们训练 BERT-large(序列长度为 512)的速度比 MLPerf 1.1 [58] 中的训练速度记录快 15%,训练 GPT2(序列长度为 1K)的速度比 HuggingFace [87] 和 Megatron-LM [77] 的基线实现快 3 倍,在长序列竞技场(序列长度为 1K - 4K)上的训练速度比基线快 2.4 倍。
  2. 更高质量的模型:FlashAttention 能够将 Transformer 模型扩展到更长的序列,这提高了模型的质量并赋予了新的能力。我们观察到,在 GPT-2 上,困惑度提高了 0.7,在长文档分类任务 [13] 中,通过对更长序列进行建模,性能提升了 6.4 个百分点。FlashAttention 使得 Transformer 首次能够在 Path-X [80] 挑战中仅通过使用更长的序列长度(16K)就取得优于随机猜测的性能。块稀疏 FlashAttention 使得 Transformer 能够扩展到更长的序列(64K),从而产生了第一个在 Path-256 上能够取得优于随机猜测性能的模型。
  3. 在常见的序列长度(从 128 到 2048)范围内,FlashAttention 的速度比标准注意力实现快最多达 3 倍,并且可扩展到 65536 的序列长度。在序列长度不超过 512 时,FlashAttention 比任何现有的注意力方法都更快且更节省内存。而对于超过 1024 的序列长度,一些近似注意力方法(例如,Linformer)开始变得更快。另一方面,块稀疏 FlashAttention 比我们所知的所有现有近似注意力方法都要快。

2 BACKGROUND

  我们介绍一些关于现代硬件(GPU)上常见深度学习操作的性能特点的背景知识。我们还将描述标准注意力机制的实现方式。

2.1硬件性能

我们这里主要关注GPU。其他硬件加速器的性能与之类似[46, 48]。

  • GPU内存层次结构:GPU内存层次结构(图1左)由多种不同大小和速度的内存组成,内存越小速度越快。例如,A100 GPU拥有40 - 80GB的高带宽内存(HBM),带宽为1.5 - 2.0TB/s,每个108个流式多处理器都有192KB的片上SRAM,其带宽估计约为19TB/s[44, 45]。片上SRAM比HBM快一个数量级,但大小要小很多个数量级。随着计算速度相对于内存速度变得更快[61, 62, 63],操作越来越受到内存(HBM)访问的限制。因此,利用快速的SRAM变得更加重要。
  • 执行模型GPU有大量的线程 threads 来执行一个操作(称为内核)。每个kernel将输入从HBM加载到寄存器和SRAM,进行计算,然后将输出写回HBM。
  • 性能特点:根据计算和内存访问的平衡情况,操作可以分为计算密集型或内存密集型。这通常通过算术强度[85]来衡量,即每字节内存访问的算术运算次数。
    • 计算密集型:操作所花费的时间由算术运算的数量决定,而访问HBM的时间要小得多。典型的例子是具有大内部维度的矩阵乘法,以及具有大量通道的卷积。
    • 内存密集型memory-bound:操作所花费的时间由内存访问的数量决定,而计算所花费的时间要小得多。包括大多数其他操作:逐元素操作(例如,激活函数、随机失活)和归约reduction操作(例如,求和、softmax、批归一化、层归一化)。
  • 内核融合Kernel fusion:加速内存密集型操作最常见的方法是内核融合:如果对同一输入应用多个操作,则可以从HBM加载一次输入,而不是为每个操作多次加载。编译器可以自动融合许多逐元素操作[53, 65, 75]。然而,在模型训练的情况下,中间值仍然需要写入HBM以保存用于反向传播,这降低了简单内核融合的有效性。

2.2标准注意力机制的实现

给定输入序列 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N×d} Q,K,VRN×d,其中 N N N是序列长度, d d d是头维度,我们想要计算注意力输出 O ∈ R N × d O \in \mathbb{R}^{N×d} ORN×d
S = Q K ⊤ ∈ R N × N , P = s o f t m a x ( S ) ∈ R N × N , O = P V ∈ R N × d S = QK^{\top} \in \mathbb{R}^{N×N}, P = softmax(S) \in \mathbb{R}^{N×N}, O = PV \in \mathbb{R}^{N×d} S=QKRN×N,P=softmax(S)RN×N,O=PVRN×d其中softmax是按行应用的。

  标准注意力机制的实现会将矩阵(S)和(P)存储到HBM中,这需要 O ( N 2 ) O(N^{2}) O(N2)的内存。通常 N ≫ d N \gg d Nd(例如,对于GPT2, N = 1024 N = 1024 N=1024 d = 64 d = 64 d=64)。我们在算法0中描述标准注意力机制的实现。由于部分或大多数操作是内存密集型的(例如,softmax),大量的内存访问导致实际运行时间较长。

  对注意力矩阵应用的其他逐元素操作elementwise operation(如对 S S S应用掩码操作或对 P P P应用随机失活操作)会加剧这个问题。因此,人们进行了许多尝试来融合几个逐元素操作,例如将掩码操作与softmax融合[77]。

Masking: 特定位置的元素设置为负无穷以进行Softmax
dropout:每个神经元以独立的概率失活,因此也是逐元素操作

  在3.2节中,我们将展示标准注意力机制的实现对HBM的访问次数与序列长度(N)的平方成正比。我们还将比较标准注意力机制和我们的方法(FlashAttention)的浮点运算次数(FLOPs)和HBM访问次数。

3 FlashAttention:算法、分析与扩展

  我们展示了如何在减少HBM读写次数且不为反向传播存储大型中间矩阵的情况下计算精确注意力。这产生了一种注意力算法,它在内存使用上更高效,在实际运行时间上也更快。我们分析了其IO复杂度,表明与标准注意力相比,我们的方法需要的HBM访问次数少得多。我们进一步展示了FlashAttention可以作为一个有用的基础,通过扩展它来处理块稀疏注意力。

为便于阐述,我们在此聚焦于前向传播;附录B包含反向传播的详细信息。

3.1 一种基于分块和重计算的高效注意力算法

  给定存储在HBM中的输入 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N×d} Q,K,VRN×d,我们旨在计算注意力输出 O ∈ R N × d O \in \mathbb{R}^{N×d} ORN×d并将其写回HBM。我们的目标是减少HBM访问量(使其在 N N N上达到亚二次方级别)。

  我们应用两种成熟的技术(分块、重计算)来克服在亚二次方HBM访问量下计算精确注意力的技术挑战。我们在算法1中对此进行描述。主要思想是将输入 Q Q Q K K K V V V分割成块,从慢速的HBM加载到快速的SRAM中,然后根据这些块计算注意力输出。通过在累加之前,用正确的归一化因子对每个块的输出进行缩放,我们最终可以得到正确的结果。

  • 分块tiling:我们按块计算注意力。Softmax会关联 K K K的列,所以我们通过缩放来分解大型的softmax计算[51, 60, 66]。为了数值稳定性,向量 x ∈ R B x \in \mathbb{R}^{B} xRB的softmax计算如下:
    m ( x ) : = m a x i x i , f ( x ) : = [ e x 1 − m ( x ) ⋯ e x B − m ( x ) ] , ℓ ( x ) : = ∑ i f ( x ) i , s o f t m a x ( x ) : = f ( x ) ℓ ( x ) m(x):=max _{i} x_{i}, f(x):=\left[\begin{array}{lll} e^{x_{1}-m(x)} & \cdots & e^{x_{B}-m(x)} \end{array}\right], \ell(x):=\sum _{i} f(x)_{i}, softmax(x):=\frac {f(x)}{\ell (x)} m(x):=maxixi,f(x):=[ex1m(x)exBm(x)],(x):=if(x)i,softmax(x):=(x)f(x)
    对于向量 x ( 1 ) , x ( 2 ) ∈ R B x^{(1)}, x^{(2)} \in \mathbb{R}^{B} x(1),x(2)RB ,我们可以将拼接后的向量 x = [ x ( 1 ) x ( 2 ) ] ∈ R 2 B x=[x^{(1)} x^{(2)}] \in \mathbb{R}^{2B} x=[x(1)x(2)]R2B的softmax分解为:
    m ( x ) = m ( [ x ( 1 ) x ( 2 ) ] ) = m a x ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) , f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] m(x)=m\left(\left[x^{(1)} x^{(2)}\right]\right)=max \left(m\left(x^{(1)}\right), m\left(x^{(2)}\right)\right), f(x)=\left[e^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right) \quad e^{m\left(x^{(2)}\right)-m(x)} f\left(x^{(2)}\right)\right] m(x)=m([x(1)x(2)])=max(m(x(1)),m(x(2))),f(x)=[em(x(1))m(x)f(x(1))em(x(2))m(x)f(x(2))]
    ℓ ( x ) = ℓ ( [ x ( 1 ) x ( 2 ) ] ) = e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) , s o f t m a x ( x ) = f ( x ) ℓ ( x ) . \ell (x)=\ell (\left[ x^{(1)} x^{(2)}\right] )=e^{m(x^{(1)})-m(x)}\ell \left( x^{(1)}\right) +e^{m\left( x^{(2)}\right) -m(x)}\ell \left( x^{(2)}\right) , softmax(x)=\frac {f(x)}{\ell (x)}. (x)=([x(1)x(2)])=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2)),softmax(x)=(x)f(x).
    因此,如果我们跟踪一些额外的统计信息( ( m ( x ) , ℓ ( x ) ) (m(x), \ell(x)) (m(x),(x))),就可以一次计算一个块的softmax。我们将输入 Q Q Q K K K V V V分割成块(算法1第3行),计算softmax值以及额外的统计信息(算法1第10行),然后合并结果(算法1第12行)。
  • 重计算Recomputation:我们的目标之一是不为反向传播存储 O ( N 2 ) O(N^{2}) O(N2)的中间值。反向传播通常需要矩阵 S S S P ∈ R N × N P \in \mathbb{R}^{N×N} PRN×N来计算关于 Q Q Q K K K V V V的梯度。然而,通过存储输出 O O O和softmax归一化统计信息 ( m , ℓ ) (m, \ell) (m,),我们可以在反向传播中从SRAM中的 Q Q Q K K K V V V块轻松重新计算注意力矩阵 S S S P P P。这可以看作是一种选择性梯度检查点技术[10, 34]。虽然梯度检查点技术已被提出用于减少所需的最大内存量[66],但据我们所知,所有实现都必须以速度换取内存。相比之下,即使浮点运算次数增加,由于减少了HBM访问,我们的重计算仍加快了反向传播(图2)。完整的反向传播描述见附录B。
  • 实现细节:内核融合:分块使我们能够在一个CUDA内核中实现我们的算法,从HBM加载输入,执行所有计算步骤(矩阵乘法、softmax、可选的掩码和随机失活、矩阵乘法),然后将结果写回HBM(掩码和随机失活见附录B)。这避免了反复从HBM读取输入和向HBM写入输出。
Algorithm 1 FLASHATTENTION
Require: Matrices Q , K , V ∈ R N × d \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d} Q,K,VRN×d in HBM, on-chip SRAM of size M M M.
1: 设置块大小 B c = ⌈ M 4 d ⌉ B_{c}=\left\lceil\frac{M}{4d}\right\rceil Bc=4dM, B r = min ⁡ ( ⌊ M 4 d ⌋ , d ) B_{r}=\min\left(\left\lfloor\frac{M}{4d}\right\rfloor, d\right) Br=min(4dM,d)
2: 在HBM中初始化 O = ( 0 ) N × d ∈ R N × d \mathbf{O} = (0)_{N\times d} \in \mathbb{R}^{N \times d} O=(0)N×dRN×d, ℓ = ( 0 ) N ∈ R N \ell = (0)_{N} \in \mathbb{R}^{N} =(0)NRN, m = ( − ∞ ) N ∈ R N m = (-\infty)_{N} \in \mathbb{R}^{N} m=()NRN
3: 将 Q \mathbf{Q} Q 划分为 T r = ⌈ N B r ⌉ T_{r}=\left\lceil\frac{N}{B_{r}}\right\rceil Tr=BrN 个块 Q 1 , … , Q T r \mathbf{Q}_{1}, \ldots, \mathbf{Q}_{T_{r}} Q1,,QTr,每个块大小为 B r × d B_{r} \times d Br×d;将 K , V \mathbf{K}, \mathbf{V} K,V 划分为 T c = ⌈ N B c ⌉ T_{c}=\left\lceil\frac{N}{B_{c}}\right\rceil Tc=BcN 个块 K 1 , … , K T c \mathbf{K}_{1}, \ldots, \mathbf{K}_{T_{c}} K1,,KTc V 1 , … , V T c \mathbf{V}_{1}, \ldots, \mathbf{V}_{T_{c}} V1,,VTc,每个块大小为 B c × d B_{c} \times d Bc×d
4: 将 O \mathbf{O} O 划分为 T r T_{r} Tr 个块 O i , … , O T r \mathbf{O}_{i}, \ldots, \mathbf{O}_{T_{r}} Oi,,OTr,每个块大小为 B r × d B_{r} \times d Br×d;将 ℓ \ell 划分为 T r T_{r} Tr 个块 ℓ i , … , ℓ T r \ell_{i}, \ldots, \ell_{T_{r}} i,,Tr,每个块大小为 B r B_{r} Br;将 m m m 划分为 T r T_{r} Tr 个块 m 1 , … , m T r m_{1}, \ldots, m_{T_{r}} m1,,mTr,每个块大小为 B r B_{r} Br
5: 对于 1 ≤ j ≤ T c 1 \leq j \leq T_{c} 1jTc 执行以下操作
6: 从HBM将 K j , V j \mathbf{K}_{j}, \mathbf{V}_{j} Kj,Vj 加载到片上SRAM。
7: 对于 1 ≤ i ≤ T r 1 \leq i \leq T_{r} 1iTr 执行以下操作
8: 从HBM将 Q i , O i , ℓ i , m i \mathbf{Q}_{i}, \mathbf{O}_{i}, \ell_{i}, m_{i} Qi,Oi,i,mi 加载到片上SRAM。
9: 在片上,计算 S i j = Q i K j ⊤ ∈ R B r × B c \mathbf{S}_{ij}=\mathbf{Q}_{i}\mathbf{K}_{j}^{\top} \in \mathbb{R}^{B_{r} \times B_{c}} Sij=QiKjRBr×Bc
10: 在片上,计算 m ~ i j = rowmax ( S i j ) ∈ R B r \tilde{m}_{ij} = \text{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_{r}} m~ij=rowmax(Sij)RBr, P ~ i j = exp ⁡ ( S i j − m ~ i j ) ∈ R B r × B c \tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_{r} \times B_{c}} P~ij=exp(Sijm~ij)RBr×Bc(逐点运算), ℓ ~ i j = rowsum ( P ~ i j ) ∈ R B r \tilde{\ell}_{ij} = \text{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_{r}} ~ij=rowsum(P~ij)RBr
11: 在片上,计算 m i new = max ⁡ ( m i , m ~ i j ) ∈ R B r m_{i}^{\text{new}} = \max(m_{i}, \tilde{m}_{ij}) \in \mathbb{R}^{B_{r}} minew=max(mi,m~ij)RBr, ℓ i new = e m i − m i new ℓ i + e m ~ i j − m i new ℓ ~ i j ∈ R B r \ell_{i}^{\text{new}} = e^{m_{i}-m_{i}^{\text{new}}}\ell_{i} + e^{\tilde{m}_{ij}-m_{i}^{\text{new}}}\tilde{\ell}_{ij} \in \mathbb{R}^{B_{r}} inew=emiminewi+em~ijminew~ijRBr
12: 将 O i ← diag ( ℓ i new ) − 1 ( diag ( ℓ i ) e m i − m i new O i + e m ~ i j − m i new P ~ i j V j ) \mathbf{O}_{i} \leftarrow \text{diag}(\ell_{i}^{\text{new}})^{-1}(\text{diag}(\ell_{i})e^{m_{i}-m_{i}^{\text{new}}}\mathbf{O}_{i} + e^{\tilde{m}_{ij}-m_{i}^{\text{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_{j}) Oidiag(inew)1(diag(i)emiminewOi+em~ijminewP~ijVj) 写入HBM。
13: 将 ℓ i ← ℓ i new , m i ← m i new \ell_{i} \leftarrow \ell_{i}^{\text{new}}, m_{i} \leftarrow m_{i}^{\text{new}} iinew,miminew 写入HBM。
14: 结束内层循环
15: 结束外层循环
16: 返回 O \mathbf{O} O

分块和重计算的实现方式
分块(Tiling)的实现
- 块大小确定:算法第1行根据SRAM的大小 M M M 和头维度 d d d 来确定块的大小 B c B_{c} Bc B r B_{r} Br B c = ⌈ M 4 d ⌉ B_{c}=\left\lceil\frac{M}{4d}\right\rceil Bc=4dM 确定了 K \mathbf{K} K V \mathbf{V} V 划分块的列数, B r = min ⁡ ( ⌊ M 4 d ⌋ , d ) B_{r}=\min\left(\left\lfloor\frac{M}{4d}\right\rfloor, d\right) Br=min(4dM,d) 确定了 Q \mathbf{Q} Q 划分块的行数。这样划分是为了让每个块能够尽可能合理地放入SRAM中进行计算,减少对HBM的访问。
- 矩阵划分:第3行和第4行分别将输入矩阵 Q , K , V \mathbf{Q}, \mathbf{K}, \mathbf{V} Q,K,V 以及中间结果矩阵 O \mathbf{O} O, ℓ \ell , m m m 按照确定的块大小进行划分。
- 分块计算:通过两层循环(第5 - 15行),外层循环遍历 K \mathbf{K} K V \mathbf{V} V 的块(每次将一个 K j \mathbf{K}_{j} Kj V j \mathbf{V}_{j} Vj 块从HBM加载到SRAM),内层循环遍历 Q \mathbf{Q} Q 的块(每次将一个 Q i \mathbf{Q}_{i} Qi 块以及对应的中间结果块加载到SRAM)。在片上对每个块进行计算,例如计算 S i j = Q i K j ⊤ \mathbf{S}_{ij}=\mathbf{Q}_{i}\mathbf{K}_{j}^{\top} Sij=QiKj (第9行) 等一系列softmax相关计算,避免一次性处理整个大矩阵,减少对HBM的依赖,实现分块计算注意力。
重计算(Recomputation)的实现
- 避免存储中间值:在反向传播中,通常需要存储矩阵 S \mathbf{S} S P \mathbf{P} P 来计算关于 Q \mathbf{Q} Q, K \mathbf{K} K, V \mathbf{V} V 的梯度,但该算法不存储这些 O ( N 2 ) O(N^{2}) O(N2) 的中间值。
- 存储关键统计信息:通过存储输出 O \mathbf{O} O 和softmax归一化统计信息 ( m , ℓ ) (m, \ell) (m,)(在算法执行过程中不断更新这些统计信息,如第11 - 13行),在反向传播时可以从SRAM中的 Q \mathbf{Q} Q, K \mathbf{K} K, V \mathbf{V} V 块重新计算注意力矩阵 S \mathbf{S} S P \mathbf{P} P。这种方式通过重新计算部分数据,避免了存储大量中间矩阵,实现了重计算,减少了内存占用。

3.2 分析:FlashAttention的IO复杂度

  我们分析了FlashAttention的IO复杂度,表明与标准注意力相比,其HBM访问次数显著减少。我们还提供了一个下限,证明对于所有SRAM大小,没有精确的注意力算法能在HBM访问次数上渐近地优于FlashAttention。证明见附录C。

定理2:设 N N N为序列长度, d d d为头维度, M M M为SRAM大小,且 d ≤ M ≤ N d d ≤ M ≤ Nd dMNd。标准注意力(算法0)需要 Θ ( N d + N 2 ) \Theta(Nd + N^{2}) Θ(Nd+N2)次HBM访问,而FlashAttention(算法1)需要 Θ ( N 2 d 2 M − 1 ) \Theta(N^{2}d^{2}M^{-1}) Θ(N2d2M1)次HBM访问。

  对于典型的 d d d(64 - 128)和 M M M(约100KB)值, d 2 d^{2} d2 M M M小很多倍,因此FlashAttention比标准实现需要的HBM访问次数少很多倍。这导致执行速度更快且内存占用更低,我们将在4.3节中进行验证。

  证明的主要思路是,给定大小为 M M M的SRAM,我们可以每次加载大小为 Θ ( M ) \Theta(M) Θ(M) K K K V V V块(算法1第6行)。对于每块 K K K V V V,我们遍历 Q Q Q的所有块(算法1第8行)来计算中间值,这导致对 Q Q Q进行 Θ ( N d M − 1 ) \Theta(NdM^{-1}) Θ(NdM1)次遍历。每次遍历加载 Θ ( N d ) \Theta(Nd) Θ(Nd)个元素,总计 Θ ( N 2 d 2 M − 1 ) \Theta(N^{2}d^{2}M^{-1}) Θ(N2d2M1)次HBM访问。我们类似地证明标准注意力的反向传播需要 Θ ( N d + N 2 ) \Theta(Nd + N^{2}) Θ(Nd+N2)次HBM访问,而FlashAttention的反向传播需要 Θ ( N 2 d 2 M − 1 ) \Theta(N^{2}d^{2}M^{-1}) Θ(N2d2M1)次HBM访问(附录B)。

  我们证明了一个下限:在计算精确注意力时,对于所有 M M M(SRAM大小)的值,不存在一种算法能使HBM访问次数渐近地优于 O ( N 2 d 2 M − 1 ) O(N^{2}d^{2}M^{-1}) O(N2d2M1)

命题3:设 N N N为序列长度, d d d为头维度, M M M为SRAM大小,且 d ≤ M ≤ N d d ≤ M ≤ Nd dMNd。不存在一种算法,对于所有 M M M [ d , N d ] [d, Nd] [d,Nd]范围内,能以 o ( N 2 d 2 M − 1 ) o(N^{2}d^{2}M^{-1}) o(N2d2M1)次HBM访问计算精确注意力。

o 代表低于这个复杂度的算法 o代表低于这个复杂度的算法 o代表低于这个复杂度的算法

  证明依赖于这样一个事实,即对于 M = Θ ( N d ) M=\Theta(Nd) M=Θ(Nd),任何算法必须执行 Ω ( N 2 d 2 M − 1 ) = Ω ( N d ) \Omega(N^{2}d^{2}M^{-1})=\Omega(Nd) Ω(N2d2M1)=Ω(Nd)次HBM访问。这种在 M M M的子范围内的下限在流算法文献[88]中很常见。我们将以 M M M为参数证明参数化复杂度[27]下限留作有趣的未来工作。

  我们验证了HBM访问次数是注意力运行时间的主要决定因素。在图2(左)中,我们看到尽管FlashAttention由于反向传播中的重计算,其浮点运算次数比标准注意力多,但它的HBM访问次数少得多,导致运行时间快得多。在图2(中),我们改变FlashAttention的块大小 B c B_{c} Bc,这导致不同的HBM访问次数,并测量前向传播的运行时间。随着块大小增加,HBM访问次数减少(因为我们对输入的遍历次数减少),运行时间也减少。对于足够大的块大小(超过256),运行时间随后会受到其他因素(如算术运算)的限制。此外,较大的块大小将无法放入较小的SRAM中。
在这里插入图片描述

3.3 扩展:块稀疏FlashAttention

  我们将FlashAttention扩展到近似注意力:我们提出块稀疏FlashAttention,其IO复杂度比FlashAttention小,减少的倍数与稀疏度成正比。给定输入 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N×d} Q,K,VRN×d和掩码矩阵 M ~ ∈ 0 , 1 N × N \tilde{M} \in{0,1}^{N×N} M~0,1N×N,我们想要计算:
S = Q K ⊤ ∈ R N × N , P = s o f t m a x ( S ⊙ 1 M ~ ) ∈ R N × N , O = P V ∈ R N × d , S = QK^{\top} \in \mathbb{R}^{N×N}, P = softmax\left(S \odot \mathbb{1}_{\tilde{M}}\right) \in \mathbb{R}^{N×N}, O = PV \in \mathbb{R}^{N×d}, S=QKRN×N,P=softmax(S1M~)RN×N,O=PVRN×d,
其中 ( S ⊙ 1 M ‾ ) k l = S k l (S \odot \mathbb{1}_{\overline{M}})_{kl}=S_{kl} (S1M)kl=Skl如果 M ‾ k l = 1 \overline{M}_{kl}=1 Mkl=1,否则为 − ∞ -\infty 。我们要求 M ˉ \bar{M} Mˉ具有块形式:对于某些块大小 B r B_{r} Br B c B_{c} Bc,对于所有 k k k l l l M ~ k , l = M i j \tilde{M}_{k, l}=M_{ij} M~k,l=Mij,其中 i = ⌊ k / B r ⌋ i=\left\lfloor k / B_{r}\right\rfloor i=k/Br j = ⌊ l / B c ⌋ j=\left\lfloor l / B_{c}\right\rfloor j=l/Bc,且 M ∈ 0 , 1 N / B r × N / B c M \in{0,1}^{N / B_{r}×N / B_{c}} M0,1N/Br×N/Bc

  给定预定义的块稀疏掩码 M ∈ 0 , 1 N / B r × N / B c M \in{0,1}^{N / B_{r}×N / B_{c}} M0,1N/Br×N/Bc,我们可以轻松修改算法1,仅计算注意力矩阵的非零块。该算法与算法1相同,只是我们跳过零块。我们在附录B的算法5中重现了该算法描述。

我们还分析了块稀疏FlashAttention的IO复杂度。

命题4:设 N N N为序列长度, d d d为头维度, M M M为SRAM大小,且 d ≤ M ≤ N d d ≤ M ≤ Nd dMNd。块稀疏FlashAttention(算法5)需要 Θ ( N d + N 2 d 2 M − 1 s ) \Theta(Nd + N^{2}d^{2}M^{-1}s) Θ(Nd+N2d2M1s)次HBM访问,其中 s s s是块稀疏掩码中非零块的比例。

我们看到,应用块稀疏性直接通过稀疏度对IO复杂度中的较大项进行了改进。对于长序列长度 N N N s s s通常设置为 N − 1 / 2 N^{-1 / 2} N1/2[11]或 N − 1 log ⁡ N N^{-1}\log N N1logN [3, 17, 92],导致 Θ ( N N ) \Theta(N \sqrt{N}) Θ(NN ) Θ ( N log ⁡ N ) \Theta(N \log N) Θ(NlogN)的IO复杂度。对于下游实验,我们使用固定的蝶形稀疏模式[17],已证明它能够近似任意稀疏度[16]。

在图2(右)中,我们验证了随着稀疏度增加,块稀疏FlashAttention的运行时间成比例地改善。在LRA基准测试中,块稀疏FlashAttention实现了2.8倍的加速,同时性能与标准注意力相当(第4节)。

4 Experiment

我们评估了使用FlashAttention训练Transformer模型的影响。我们验证了关于训练时间和模型准确性的两个观点,并报告了注意力机制的运行时间和内存基准测试结果。

  • 训练速度:FlashAttention在训练BERT模型时,比MLPerf 1.1 [58]的速度记录快15%;在训练GPT - 2模型时,相较于HuggingFace [87]使用的标准Transformer快3倍,相较于Megatron [77]快1.8倍。FlashAttention使长序列竞技场(LRA)基准测试的速度提升了2.4倍。
  • 质量:FlashAttention可以让Transformer处理更长的序列,从而产生更高质量的结果。FlashAttention训练上下文长度为4K的GPT - 2模型,比Megatron训练上下文长度为1K的GPT - 2模型更快,同时困惑度低0.7 。对更长的序列进行建模,在两个长文档分类任务上的效果提升了6.4个百分点。最后,FlashAttention训练出了首个在具有挑战性的Path - X任务(序列长度16K)上表现优于随机的Transformer模型,并且块稀疏FlashAttention产生了我们所知的首个在Path - 256任务(序列长度64K)上表现优于随机的序列模型。
  • 注意力机制基准测试:我们基于序列长度测量了FlashAttention和块稀疏FlashAttention的运行时间和内存性能。我们证实,FlashAttention的内存占用与序列长度呈线性关系,并且在常见的序列长度(最长2K)下,比标准注意力机制快3倍。我们还证实,块稀疏FlashAttention的运行时间与序列长度呈线性关系,并且比所有现有的近似注意力基线更快。
    更多实验细节见附录E。

4.1 借助FlashAttention实现更快的模型

  • BERT:FlashAttention实现了我们所知的最快的单节点BERT训练速度。我们使用FlashAttention在维基百科数据集上训练一个BERT - large [22]模型。表1将我们的训练时间与英伟达创造MLPerf 1.1 [58]训练速度记录的实现进行了比较。我们的实现快15%。
  • GPT - 2:在大型OpenWebtext数据集[32]上,使用FlashAttention训练GPT - 2 [67]的时间,比广泛使用的HuggingFace [87]和Megatron - LM [77]实现要快。表2显示,与HuggingFace相比,端到端速度提升高达3倍,与Megatron - LM相比提升1.7倍。由于我们没有改变模型定义,FlashAttention实现了与另外两种实现相同的困惑度。附录E包含了整个训练过程中验证困惑度的图表,证实了FlashAttention在数值上与基线一样稳定,并产生了相同的训练/验证曲线。
  • 长序列竞技场:我们在长序列竞技场(LRA [80])基准测试中比较了普通Transformer(采用标准实现或FlashAttention)。我们测量了所有模型的准确性、吞吐量和训练时间。每个任务的序列长度不同,介于1024和4096之间。我们遵循[80]中的实现和实验设置,以及[90]中的变化。表3显示,与标准注意力机制相比,FlashAttention实现了2.4倍的速度提升。块稀疏FlashAttention比我们测试过的所有近似注意力方法都要快。

4.2 利用更长序列构建更好的模型

  • 长上下文语言建模:FlashAttention的运行时间和内存效率使我们能够将GPT - 2的上下文长度增加到原来的4倍,同时运行速度仍比Megatron - LM的优化实现更快。表4显示,使用FlashAttention且上下文长度为4K的GPT - 2,比上下文长度为1K的Megatron版GPT - 2仍快30%,同时困惑度低0.7。与Megatron - LM相比,使用FlashAttention的小版本GPT - 2,上下文长度增加到4倍,速度仍快30%,困惑度低0.7。表中给出了在8个A100 GPU上的训练时间。
  • 长文档分类:使用FlashAttention和更长的序列来训练Transformer模型,可以提升在MIMIC - III [47]和ECtHR [6, 7]数据集上的性能。MIMIC - III包含重症监护病房患者的出院总结,每个总结都标注了多个标签。ECtHR包含来自欧洲人权法院的法律案件,每个案件都对应到被指控违反的《欧洲人权公约》条款。这两个数据集都包含非常长的文本文档;MIMIC数据集中的平均词元数为2395,最长的文档包含14562个词元,而在ECtHR数据集中,平均词元数为2197,最长的文档包含49392个词元。我们从增加预训练RoBERTa模型[56]的序列长度中评估其提升(我们像Beltagy等人[3]中那样重复位置嵌入)。表5显示,在MIMIC数据集上,序列长度为16K时的性能比长度为512时高4.3个百分点;在ECtHR数据集上,长度为8K时的性能比长度为512时高8.5个百分点。这些差异可能是由于细微的分布变化导致的:MIMIC - III包含专业的医学文本,因此可能更容易受到文档长度分布变化的影响,而ECtHR包含的是通用语言。
  • Path - X和Path - 256:Path - X和Path - 256基准测试是长序列竞技场基准测试中具有挑战性的任务,旨在测试长上下文处理能力。任务是判断在一个128×128(或256×256)的黑白图像中,两个点之间是否存在连接路径,图像以每次一个像素的方式输入到Transformer模型中。在先前的研究中,所有Transformer模型要么出现内存不足的情况,要么只能达到随机的性能[80]。人们一直在寻找能够对这种长上下文进行建模的替代架构[37]。我们在此展示了Transformer模型首次能够解决Path - X和Path - 256任务的结果(表6)。我们在Path - 64上对Transformer进行预训练,然后通过对位置嵌入进行空间插值迁移到Path - X任务上。FlashAttention在Path - X任务上达到了61.4%的准确率。此外,块稀疏FlashAttention使Transformer能够处理长度为64K的序列,在Path - 256任务上达到了63.1%的准确率。

4.3 注意力机制基准测试

我们改变序列长度,并在配备40GB HBM的单个A100 GPU上,针对各种注意力基线,测量FlashAttention和块稀疏FlashAttention的运行时间和内存使用情况,测试中包含随机失活和填充掩码。我们将其与精确注意力、近似注意力和稀疏注意力的参考实现进行比较。我们在正文中报告部分基线;附录E包含更多基线和完整细节。

  • 运行时间:图3(左)以毫秒为单位报告了FlashAttention和块稀疏FlashAttention的前向传播 + 反向传播的运行时间,并与精确注意力、近似注意力和稀疏注意力的基线进行了比较(附录E中有精确数值)。运行时间随序列长度呈二次方增长,但FlashAttention的运行速度明显快于精确注意力基线,比PyTorch实现快3倍。许多近似/稀疏注意力机制的运行时间随序列长度呈线性增长,但由于内存访问次数较少,FlashAttention在短序列上的运行速度仍比近似注意力和稀疏注意力快。近似注意力的运行时间在序列长度为512到1024之间开始与FlashAttention的运行时间交叉。另一方面,块稀疏FlashAttention在所有序列长度上都比我们所知的精确注意力、稀疏注意力和近似注意力的所有实现都要快。
  • 内存占用memory footprint:图3(右)展示了FlashAttention和块稀疏FlashAttention与各种精确注意力、近似注意力和稀疏注意力基线相比的内存占用情况。FlashAttention和块稀疏FlashAttention的内存占用情况相同,随序列长度呈线性增长。FlashAttention的内存效率比精确注意力基线高20倍,比近似注意力基线的内存效率也更高。除了Linformer之外,所有其他算法在序列长度达到64K之前,在A100 GPU上就会出现内存不足的情况,而FlashAttention的内存效率仍比Linformer高2倍。
  • 在这里插入图片描述

5 局限性和未来方向

我们讨论该方法的局限性以及未来的研究方向。相关工作见附录A。

  • 编译新的CUDA内核:我们目前构建注意力机制的IO感知实现方法,需要为每个新的注意力实现编写一个新的CUDA内核。这要求以比PyTorch低级得多的语言编写注意力算法,并且需要大量的工程工作。此外,实现可能无法在不同的GPU架构之间移植。这些局限性表明,需要一种支持以高级语言(如PyTorch)编写注意力算法,并将其编译为CUDA中IO感知实现的方法,就像Halide在图像处理中所做的工作一样[70]。
  • IO感知的深度学习:我们认为IO感知方法可以扩展到注意力机制之外。在Transformer中,注意力是最消耗内存的计算,但深度学习中的每一层在某种程度上都涉及到GPU的HBM访问。我们希望我们的工作能启发更多IO感知实现的附加模块。附录D中讨论了这些潜在的扩展。
  • 多GPU的IO感知方法Multi - GPU IO - Aware Methods:我们的IO感知注意力实现,在单个GPU的计算限制内是最优的。然而,注意力计算可以在多个GPU之间并行化[72]。使用多个GPU为IO分析增加了一个额外的层次,即要考虑GPU之间的数据传输。我们希望我们的工作能启发在这个方向上的未来研究。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值