我相信你已经听过很多次:“注意力机制在序列长度N上的扩展性很差,具体来说是O(N²)”。但是,究竟是什么在序列长度上表现不佳?为什么它会表现不佳?真的是序列长度的问题吗?在这个系列中,我们将深入探讨Tri Dao等人(2022年)提出的FlashAttention论文,并探索为什么注意力机制在现代GPU上如此缓慢和资源密集,以及FlashAttention是如何解决这个问题的。
文章分成两部分。
在第一部分中,我们将介绍注意力机制、GPU架构和CUDA编程模型,解释为什么注意力机制在现代GPU上如此缓慢和资源密集,以及是什么阻碍了我们应用CUDA编程中已知的常见优化方法。
在第二部分中,我们将深入探讨FlashAttention,基于第一部分中获得的知识,逐步剖析前向传播和反向传播的算法细节。
🚀 不多废话,让我们开始吧!
文章目录
1. 为什么我们要关心这个问题?
在深入细节之前,让我们先花点时间思考一下,为什么我们要关心这个话题。注意力机制是Transformer架构的核心组件,而Transformer架构又是当今深度学习各个领域的基础。随着模型变得越来越大,训练和推理它们的成本也越来越高,无论是时间还是金钱。因此,如果我们能让注意力机制更快、更高效,我们就可以用更少的时间和金钱训练和推理更大的模型,甚至可以训练以前无法训练的模型。
下图很好地说明了这一点。它展示了FlashAttention与之前方法的性能对比,以及它在序列长度N上的扩展能力。
图1:FlashAttention相较于之前方法的改进。
如你所见,FlashAttention比之前的方法快得多,并且可以处理之前方法因内存不足而无法处理的序列长度。
有了这个认识,让我们先简要介绍一下注意力机制。
2. 注意力机制简介
注意力机制是由Vaswani等人在2017年的论文《Attention is All You Need》中提出的Transformer架构的核心组件。从高层次来看,注意力机制允许模型在处理某个token时,动态地关注输入序列的不同部分,同时忽略其他部分。Token是输入序列中的单个元素,可以是一个单词、图像中的一个像素块,甚至是图中的一个节点。
对于输入序列中的每个token,都会计算其与其他所有token的注意力分数。
最初,注意力机制是为处理文本序列而引入的,但后来它被广泛应用于各种任务和数据模态,如图像、视频和图,成为捕获序列中元素之间依赖关系的通用机制。
为了帮助你理解,让我们来看一个简单的例子:
图2:注意力机制在文本、图像和图中的直观解释。
简单来说,我喜欢把注意力机制想象成一种向数据提问的机制,而注意力分数就是问题的答案。注意力分数越高,说明该信息对当前token越重要。多头注意力(MHA)实际上就是同时为同一个token提出多个问题,然后将答案结合起来。稍后我们会详细讨论这一点。
有了这个直观的认识,让我们更详细地看看Transformer编码器架构和注意力机制。
图3:将注意力机制放入Transformer编码器的上下文中。
Transformer编码器由多个相同的阶段堆叠而成。每个阶段由两个子阶段组成:多头自注意力(MHA)和多层感知机(MLP)。
MHA子阶段实现了多个并行的注意力头,然后将它们连接起来,并线性变换到预期的隐藏维度。
深入到多头注意力层的单头架构中,我们可以看到注意力机制的架构,它可以用以下公式描述:
图4:注意力机制公式。
注意,Dropout和掩码是可选的,但由于FlashAttention提到了它们,我也会在这里提到。
由于这篇文章的重点不是注意力机制本身,因此我不会深入探讨细节。如果你对注意力机制感兴趣,我推荐阅读Vaswani等人的原始论文或Jay Alammar的著名博客文章。
我们感兴趣的是为什么注意力机制在现代GPU上如此缓慢和资源密集。为了理解这一点,让我们先来看看注意力机制中涉及的矩阵形状。
图5:注意力机制中涉及的矩阵形状。
关键点在于,注意力机制在中间结果中生成了NxN矩阵,而我们感兴趣的输出形状仅为Nxd,其中N是序列长度,d是隐藏维度。这就是Transformer的二次计算和内存复杂度O(N²)的原因,其中N是序列长度。虽然这个例子中提到的典型值是N=4,096(当时的情况),但论文将其扩展到N=65,536,这将导致中间结果达到42.9亿个值!
注意,图5展示的是单层中单头的注意力机制。在实际应用中,Transformer使用了多个头和多个层,这进一步大幅增加了计算复杂度。
为了进一步理解为什么注意力机制在现代GPU上如此缓慢和资源密集,我们需要先简要了解一下GPU的硬件架构以及CUDA编程的基础知识。放心,我会尽量简短明了!😉
3. 现代GPU与CUDA设备架构
3.1. GPU架构
让我们先比较一下CPU和GPU的架构。CPU由少数几个优化用于顺序处理的核心组成。相比之下,GPU拥有大规模并行架构,由数千个更小、更高效的核心组成,这些核心被设计用于同时处理多个任务。CPU和GPU都有不同大小和速度的内存层次结构。这在后面讨论FlashAttention如何实现巨大加速时非常重要。
图6:CPU与GPU架构对比。
这种通用GPU架构以及CUDA编程语言是由Nvidia在2006年的Tesla架构中引入的,其许多核心概念至今仍存在于现代GPU中。我强烈推荐阅读Nvidia的Tesla架构白皮书。
GPU架构围绕流式多处理器(SM)构建,每个SM包含多个CUDA和Tensor核心、共享内存和寄存器。每个SM可以并行执行多个线程块。同一个块内的线程可以通过共享内存相互通信,而不同块的线程则不能。
这里有一个小剧透:我们现在有如此多的线程,实际上是成千上万个,我们可以将繁重的计算任务(如矩阵乘法和卷积)并行化,将每个线程分配到一小部分数据上进行计算。
3.2. CUDA编程模型
Nvidia GPU使用CUDA编程模型进行编程,这是C++编程语言的扩展。它允许开发者编写CUDA内核,这些内核是由每个线程执行的小函数。这些线程被组织成块,块又被组织成网格。
图7:CUDA线程层次结构。
这种线程层次结构是CUDA的一个巧妙设计,因为它模仿了我们正在处理的数据的形状,即向量、矩阵和张量。此外,每个线程都知道自己在网格和块中的位置,这使我们能够计算出我们在处理的数据结构中的位置。
我们还可以将CUDA编程模型映射到GPU架构上,以了解数据存储的位置以及线程如何访问GPU内存(全局内存)中的数据。
图8:CUDA编程模型与GPU架构对比。
为了引入FlashAttention中的内核融合概念,我们需要了解CPU如何与GPU交互,以及数据是如何流动的。让我们从异构编程模型开始,这意味着我们使用不同的设备来执行代码,即CPU和GPU。
图9:使用CPU和GPU的异构编程模型。
CPU执行串行代码,并负责在主机和设备之间传输数据,以及启动内核。GPU则在设备的线程块网格中并行执行这些内核。结果被读回CPU,CPU继续其串行处理,并最终启动另一个内核。关键点在于,在朴素实现中,每个内核都需要在设备和主机之间传输数据。将其放在注意力机制的上下文中,这些顺序执行的内核可能是矩阵乘法、缩放、掩码、SoftMax和Dropout。
放大单个内核执行的数据流,我们可以看到数据在CPU和GPU之间流动的方式,以及线程如何访问GPU内存(全局内存)中的数据。
图10:单个CUDA内核执行的数据流。
不用说,上图被极度简化了,但它应该能给你一个大致的概念。关键点在于数据需要频繁传递,并且核心经常在等待数据加载。
好了,信息量很大。让我们通过一个例子来更清楚地说明这一点。最典型的例子是矩阵乘法,这是深度学习中的关键操作。
3.3. CUDA中的矩阵乘法
正如你所知,要计算两个矩阵A和B的乘积,我们需要取矩阵A的每一行与矩阵B的每一列的点积。在一个朴素(且极其低效)的实现中,我们可以通过为每个线程分配计算结果矩阵C的一个元素的任务来在CUDA中复制这个算法。
图11:矩阵乘法。
注意,计算结果矩阵C的整行需要矩阵A的一行和矩阵B的所有列。
同样,我们必须稍微绕道了解一下内核的执行方式以及不同内存带宽的情况。每次我们启动一个内核时,都会构建线程块网格,将内核加载到核心中,并将数据从全局内存加载到处理器的寄存器中。然后执行内核,并将结果存储回全局内存。一般来说,内存越大,离处理器越远,速度就越慢。因此,全局内存是层次结构中最大但最慢的内存,而单个核心的寄存器是最小但最快的内存。
为了让你有个概念(因为A100的1.5 TB/s听起来还不算太差),让我们看看我们需要加载多少数据来计算两个大小为NxN的矩阵A和B,其中N=4096。我们假设值是以float32格式存储的,这意味着每个元素占用4字节。
• 我们需要从矩阵A读取4096x4096个元素,从矩阵B读取4096x4096个元素:2x4096x4096x4字节=134MB。
• 我们需要在矩阵C中存储4096x4096个元素:4096x4096x4字节=64MB。
如果我们现在按照朴素实现的矩阵乘法,我们实际上会为矩阵A的每一行加载整个矩阵B。这意味着我们需要加载4096x4096个元素4096次,即134MB x 4096 = 549GB,这可是相当大的流量啊😅,而这仅仅是针对我们小上下文窗口N=4096个token的情况……
❓那么,知道这些后,有什么技巧可以让GPU上的计算更快、更高效呢?
3.4. 内核执行的优化
正如你可能已经猜到的,我们的主要关注点是减少对全局内存的读写次数,因为它是层次结构中最慢的内存,而我们的核心经常在等待数据加载。
有几种技术可以实现这一点,包括:
• 共享内存 —— 将频繁使用的数据放在核心附近。
• 块分块(Block Tiling) —— 重用已经加载的数据来计算多个输出。
• 内核融合 —— 避免将中间结果写入和读取全局内存。
注意,我们在这里会略过一些细节,以便简化内容,同时构建理解FlashAttention所需的基础知识。如果你对CUDA中的矩阵乘法及其优化有更详细的了解感兴趣,我推荐阅读Simon Boehm的这篇精彩且直观的博客文章:CUDA-MMM。
不过现在,让我们来谈谈前面提到的优化,以及它们如何应用于矩阵乘法。
共享内存
将频繁使用的数据放在核心附近。
我们要做的第一件事是利用层次结构中的不同内存级别。共享内存是一种小型、快速的内存,它在块内的线程之间共享(在单个流式多处理器中的所有核心之间)。正如FlashAttention论文中提到的,对于A100 GPU,全局内存的带宽为40GB,速度为1.5TB/s,而共享内存的带宽为20MB,速度为19TB/s。
我们知道,块内的线程可以通过共享内存相互通信,并且块可以在不同的流式多处理器上并行执行。
图12:矩阵乘法中共享内存的使用
因此,对于我们处理的每个块,我们将数据从全局内存加载到共享内存,然后块内的线程可以从共享内存访问数据。当处理下一个块时,我们重复这个过程。
块分块(Block Tiling)
重用已经加载的数据来计算多个输出。
接下来,我们要做的是在一个线程中计算结果矩阵C的多个元素。这称为分块,是一种常见的技术,用于减少对全局内存的读写次数。基本思想是将已经从共享内存加载到寄存器中的数据重用于多个计算。
图12:矩阵乘法中的块分块。
需要明确的是,我们仍然执行相同数量的计算,只是减少了从内存中读取数据的次数。这意味着我们花费更少的时间等待数据,而将更多时间用于计算。这提高了所谓的算术强度(Arithmetic Intensity),即操作次数与内存访问次数的比率。
内核融合(Kernel Fusion)
避免将中间结果写入和读取全局内存。
最后,我们可以将多个内核融合为一个内核,以避免将中间结果写入全局内存,仅仅为了在下一个内核中再次读取它们。
图13:内核融合如何影响运行时间和内存消耗。
重要提示:只有当内核操作的是同一块分块的数据时,才能进行内核融合。在讨论注意力机制时,这一点非常重要。
有了关于注意力机制、CUDA和GPU设备架构的所有这些知识,我们现在可以理解为什么注意力机制在现代GPU上如此缓慢和资源密集了!
4. 现代GPU上的注意力机制
4.1. 为什么注意力机制如此缓慢和资源密集?
正如你现在可能已经猜到的,我们需要频繁地将大量数据写入和读取全局内存。
让我们再看一下注意力机制及其涉及的矩阵形状,但这次是从GPU架构和CUDA编程模型的角度来分析。
图14:注意力层每个内核的读写操作。
我们可以得出以下几个观察结果:
- 我们有3个输入矩阵(Q、K和V),每个大小为Nxd,以及一个输出矩阵O,大小为Nxd。
- 注意力机制在中间结果中生成了巨大的NxN矩阵,这些矩阵都需要存储,并在执行下一个内核时再次读取。
- 像矩阵乘法和SoftMax这样的操作会重复使用其输入数据的相同部分来计算输出的单个元素。
- 像Dropout和掩码这样的操作是按元素应用的。
正如你所见,这里有很多昂贵的全局内存读写操作,而最糟糕的是,我们对这些中间结果并不感兴趣。
那么,让我们看看如何将前面提到的优化应用到注意力机制上。或者更确切地说,让我们看看为什么我们不能直接应用这些优化(在深入探讨FlashAttention如何解决这些问题之前)。
4.2. 我们如何改进注意力机制?
注意力机制的主要问题是,我们必须在中间结果中生成NxN矩阵,然后将其写入和读取全局内存。那么,为什么不简单地将相关的内核融合在一起,以避免写入和读取中间结果呢?
这是一个很好的想法,但有两个问题:
- 并非所有这些内核都操作同一块分块的数据,而这是内核融合所必需的(如前所述)。
- 我们需要在反向传播中保留中间结果以计算梯度。
稍微放慢一点速度,让我们看看第一个问题。SoftMax函数在实践中是如何实现的?让我们看一下我们的NxN矩阵。
图15:注意力机制中SoftMax的计算方式。
首先,注意SoftMax是按行应用的,而不是在整个矩阵上!其次,更重要的是,为了计算SoftMax矩阵P的单个元素,我们需要读取对应行的所有元素。
现在回想一下,矩阵乘法内核应用了块分块和滑动窗口,将输入加载到共享内存中以计算部分和。这与SoftMax函数不兼容,因为我们需要整行的所有元素来计算输出的单个元素。
将SoftMax函数按行应用的直觉是什么?NxN矩阵的每一行都与输入序列中的一个token相关。按行计算SoftMax会产生一个分数,表示序列中的其他token对该token的“关注”程度。换句话说,其他token对当前token完成任务的重要性。
假设我们可以将SoftMax函数应用于与矩阵乘法内核相同的块分块数据。这将允许我们将内核融合在一起,避免写入和读取中间结果。但我们仍然有第二个问题:我们需要在反向传播中保留中间结果以计算梯度。
为了训练我们的模型,我们通常会输入一批数据,计算与某个目标的损失,然后应用反向传播来更新模型的参数,以便下次输入数据时表现更好(希望如此)。通过应用链式法则,我们可以计算模型中每个操作的梯度。让我们从整个模型中提取一个简化的注意力机制:
图16:前向传播与反向传播的对比。
注意,前向传播的中间结果需要用来计算梯度。
因此,我们需要将中间结果保留在内存中,这意味着我们必须将其写入全局内存,并在反向传播中再次读取。这正是我们最初想要避免的……所以我们又回到了起点……
幸运的是,FlashAttention解决了这些问题。但现在,让我们结束第一部分,给自己一些时间消化这些信息。
这只是深入探讨FlashAttention及其背后的工程的开始。我非常想听听其他正在研究注意力机制或GPU优化的人——你们遇到了哪些挑战?有没有探索过什么见解或替代方法?请在评论区留言。
如果你觉得这个拆解有帮助,请点几个赞,帮助其他人发现它!敬请期待第二部分!
5. 总结
5.1. 总结
我们的目标:
• 让注意力机制更快、更高效。
❓为什么它会如此缓慢和资源密集?
• 我们需要在中间结果中生成NxN矩阵,并将其写入和读取全局内存。
❓为什么我们不能简单地将内核融合在一起以避免写入和读取中间结果?
• 并非所有这些内核都操作同一块分块的数据,这是内核融合所必需的。
• 我们需要在反向传播中保留中间结果以计算梯度。
5.2. 继续阅读第二部分
在第二部分中,我们将深入探讨FlashAttention,基于第一部分中获得的知识,逐步剖析前向传播和反向传播的算法细节。
6. 参考
论文
• FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Tri Dao et al. (2022)
• Nvidia Tesla: A Unified Graphics and Computing Architecture by Nvidia (2006)
博客文章
• The Illustrated Transformer by Jay Alammar.
• GPU Glossary by Modal.
• How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog by Simon Boehm.
• Matrix Multiplication On GPU: Part 2, Tiling by Lawrence Murray.
• Numerically Stable Softmax by Jay Mody.