大模型推理--FlashAttention

Attention机制可以算是Transformer的灵魂。正因为有了attention,模型的效果才能大幅提升。但同样是因为attention,导致transformer很难处理超长上下文,因为attention占用显存的大小与上下文长度的平方成正比,会导致上下文很长时显存爆炸。FlashAttention正是为了解决显存爆炸而设计的,它不光解决了显存爆炸的问题,同时也加速了attention的计算,并从数学上保证了结果的一致性。

1. Self Attention原理

Attention的计算涉及三个矩阵:Q、K、V,这三个矩阵在送给attention计算时都有相同的维度。我们先不考虑multi-head attention,只考虑one head的self attention。初始时,这三个矩阵的维度均为N x d,N即为上下文的长度(当前的大模型普遍支持的N上限为128K,谷歌也有大模型可以到1M)。通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(d QKT)V
我们将attention的运算拆分,首先是 Q Q Q K T K^T KT的矩阵乘,它会生成一个N x N的矩阵,算法复杂度为 O ( N 2 d ) O(N^2d) O(N2d)。接着对N x N矩阵每一个元素进行缩放,并对每一行求softmax,这个过程不会改变矩阵的维度,算法复杂度为 O ( N 2 ) O(N^2) O(N2)。最后再和矩阵V相乘得到结果O,最终的矩阵维度N x d,和输入的三个矩阵的维度保持一致,算法复杂度同样为 O ( N 2 d ) O(N^2d) O(N2d),所以总的算法复杂度也为 O ( N 2 d ) O(N^2d) O(N2d)。attention运算会保持输入和输出矩阵维度的一致性,但是在具体的实现过程中,我们还是不可避免的要产生一个N x N的矩阵,这个矩阵在超长上下文下的显存占用非常可观。按照当前大模型普遍的上下文长度要求128K来算,N x N矩阵的大小为128Kx128K=16G,如果矩阵用fp16来存储,那单单这一个矩阵就要占32G显存,真的是显存占用不可限量!比较令人郁闷的是,我们最终计算的结果是N x d的矩阵,一般情况下d都要远远小于N(假设d为4K,最终的矩阵大小也仅为512M),但是为了保证结果的正确性我们不得不生成一个N x N的大矩阵。这也是为什么早期的大模型支持的上下文长度普遍较短的原因–attention占用的显存太大。

2. Multi-head Attention原理

在真正使用attention的时候,我们往往采用multi-head attention。Multi-head attention的计算公式和self attention基本一致,它改变了 Q 、 K 、 V Q、K、V QKV每一行的定义:将维度d的向量分成h组变成一个 h ∗ d k h * d_k hdk的矩阵, Q 、 K 、 V Q、K、V QKV此时成为了 N ∗ h ∗ d k N * h * d_k Nhdk的三维矩阵(不考虑batch维)。分别将 Q 、 K 、 V Q、K、V QKV的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k hNdk的三维矩阵。此时的三个矩阵就是具有h个头的 Q 、 K 、 V Q、K、V QKV,我们就可以按照self attention的定义计算h个头的attention值。算法复杂度为 O ( h N 2 d k ) = O ( N 3 ) O(hN^2d_k)=O(N^3) O(hN2dk)=O(N3)。与self attention相比,multi-head attention的算法复杂度由 O ( N 2 d ) O(N^2d) O(N2d)变为 O ( N 3 ) O(N^3) O(N3),在长上下文的情况下计算量增大了很多。不仅计算量增大很多,在中间运算的过程中会产生h个N x N的矩阵的矩阵,所以显存占用也是大了h倍。这也说明如果不优化显存占用,multi-head attention的上下文长度无法扩展到很大。

3. FlashAttention原理

正当大家通过购买大显存显卡或者采用稀疏注意力来增大上下文长度时,22年斯坦福的一个博士吹岛(Tri Dao)发了一篇论文《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》,FlashAttention应运而生。FlashAttention不光解决了multi-head attention在计算过程中显存占用过大的问题,将 O ( N 2 ) O(N^2) O(N2)级的显存占用优化到了 O ( N d ) O(Nd) O(Nd),而且它的计算结果还是精确的不是近似的,更重要的是attention的计算速度也变快了很多,加快了大模型推理prefilling阶段的速度,使得TTFT(time to first token)的时间大幅减小。不过请大家注意,FlashAttention解决了Attention计算过程中显存占用过大的问题,但是在整个大模型推理中KV Cache也会占用大量显存,也是需要优化的点,我们会在后续的博客中进行介绍。

FlashAttention解决显存占用过大采用的方法就是分块(Tiling),将完整 Q K T QK^T QKT的计算分成一个一个小块来实现。因为块足够小,小到可以直接在共享内存(shared memory)中放下,从而加快了显存的访问,进而也加快了计算速度。分块不是特别新鲜的技术,矩阵乘的加速就利用了分块的技巧,但是self attention的计算不止 Q K T QK^T QKT,还包括对该矩阵求softmax。FlashAttention最大的创新点就来自于分块计算的同时还保证了softmax计算的正确性。我们知道,对一个向量求softmax,我们必须得获得完整的向量才可以,显然分块破坏了这个前提。理论上,没有获得完整的向量前我们是无法计算softmax的,但是我们可以在分块的过程中迭代计算softmax,原理也不是那么复杂,下面给出解释。

3.1 softmax的分块计算

首先,在实际计算softmax的时候,为了避免指数运算的数值溢出,往往会利用safe softmax求softmax,这两个是等价的,只不过多了一步求max的过程。给定一个向量 X = [ x 1 , x 2 , . . . , x d ] X=[x_1,x_2,...,x_d] X=[x1,x2,...,xd],我们先求得向量的最大值
m ( x ) = max ⁡ i ( x i ) \begin{equation} m(x)=\max_{i} (x_i) \end{equation} m(x)=imax(xi)
,有了最大值之后我们就可以利用公式:
s o f t m a x ( x i ) = e x i − m ( x ) ∑ j = 1 d e x j − m ( x ) \begin{equation} softmax(x_i)=\frac {e^{x_i-m(x)}} {\sum_{j=1}^d e^{x_j-m(x)}} \end{equation} softmax(xi)=j=1dexjm(x)exim(x)
求softmax。每一项减最大值再求exp就可以保证每一项的结果都在(0,1]之间,从而避免了exp过大导致溢出。
现在我们将向量 X X X分成 T c T_c Tc段,每段长度 B c B_c Bc,每一段分别去求safe softmax。从公式(2)可以看出,如果不做特殊处理,每一段计算的softmax值肯定是错的,因为softmax的计算要依赖一个全局的max值,但是每一段获取的只是一个局部max值,并且分母求和也只考虑了当前段的和没有考虑所有段的和。所以为了获得正确的softmax值,我们就需要让分子和分母都按照正确的定义去计算。FlashAttention通过在迭代计算每一段的safe softmax过程中维护三个变量 m , f , ℓ m,f,ℓ mf使得迭代结束之后就可以获得正确的softmax值。

m m m的含义是每一段向量的最大值,也即
m i ( X ) = max ⁡ i ∗ B c ≤ j < ( i + 1 ) ∗ B c ( x j ) \begin{equation} m_i(X)=\max_{i*B_c \le j \lt (i+1)*B_c} (x_j) \end{equation} mi(X)=iBcj<(i+1)Bcmax(xj)
m的长度为 T c T_c Tc。在实际的代码中,FlashAttention还会维护一个到当前段为止的全局最大值,我们记为 M , M M,M MM的更新公式为
M i = max ⁡ ( M i − 1 , m i ) \begin{equation} M_i=\max (M_{i-1},m_i) \end{equation} Mi=max(Mi1,mi)
f f f的含义是每一段每一个元素利用当前段计算的最大值求 e x − m e^{x-m} exm,也即
f i j ( X ) = e x j − m i ( 0 < i < T c , i ∗ B c ≤ j < ( i + 1 ) ∗ B c ) \begin{equation} f_{ij}(X)= e^{x_j-m_i} (0<i<T_c, i*B_c \le j <(i+1)*B_c) \end{equation} fij(X)=exjmi(0<i<Tc,iBcj<(i+1)Bc)
上述公式的含义是第i段中,每一个元素都减去段最大值 m i m_i mi并求指数, f f f的长度为d。正如前面所说,这样的计算方法是错误的,我们需要进行修正。假设我们当前遍历到第i段,这i段的 f f f值已经求出,根据公式(5)我们知道每一段的 f f f值减的都是段内局部最大值 m i m_i mi,正确的计算应该是减去前i段的全局最大值 M i M_i Mi,我们可以根据变量 m m m M M M来更新每一段的 f f f,方法就是每一段的 f f f值乘以 e m i − M i e^{m_i-M_i} emiMi,然后利用指数运算的公式我们可以得到:
f i j ( X ) ∗ e m i − M i = e x j − m i ∗ e m i − M i = e x j − m i + m i − M i = e x j − M i \begin{equation} f_{ij}(X) *e^{m_i-M_i}=e^{x_j-m_i}*e^{m_i-M_i}=e^{x_j-m_i+m_i-M_i}=e^{x_j-M_i} \end{equation} fij(X)emiMi=exjmiemiMi=exjmi+miMi=exjMi
如此我们就得到遍历完i段时正确的 f f f值。当然,每一段我们都需要更新前面所有的 f f f值,总的复杂度为 O ( d 2 ) O(d^2) O(d2)。但是 f f f的计算更多是概念性的,在真实的代码中是不可能按照如此高的复杂度去计算所有的 f f f,原理在于我们最终计算完softmax之后还要和 V V V相乘,我们在迭代计算的过程中进行更新即可,此时复杂度为 O ( d ) O(d) O(d)

从上面的介绍可以看出,为了获得safe softmax中正确的分子值,我们可以在分段计算的过程中,将之前的结果乘以 e m i − M i e^{m_i-M_i} emiMi进行矫正,所以 e m i − M i e^{m_i-M_i} emiMi可以看做是矫正因子。同理,分母值的更新也可以通过矫正因子实现。
ℓ ℓ 表示每一段 f f f的累加值。也即
ℓ i ( X ) = ∑ j = i ∗ B c ( i + 1 ) ∗ B c − 1 f i j ( X ) \begin{equation} ℓ_{i}(X)= \sum_{j=i*B_c}^{(i+1)*B_c-1}f_{ij}(X) \end{equation} i(X)=j=iBc(i+1)Bc1fij(X)
ℓ ℓ 的长度为 T c T_c Tc。和 f f f的计算类似,我们也可以通过将每一个 ℓ ℓ 乘以 e m i − M i e^{m_i-M_i} emiMi将求和的结果进行矫正。对 ℓ ℓ 求和就可以获得到当前段的所有元素求exp的和,也即safe softmax的正确分母值,我们记为 L L L。在更新 L L L的时候我们按照下述公式
L i = e M i − 1 − M i ∗ L i − 1 + e m i − M i ∗ ℓ i \begin{equation} L_{i}=e^{M_{i-1}-M_i}*L_{i-1}+e^{m_i-M_i}*ℓ_{i} \end{equation} Li=eMi1MiLi1+emiMii
迭代完成,复杂度为 O ( T c ) O(T_c) O(Tc) L i L_i Li表示到第i段为止 L L L的总和,它包含两部分,前i-1段的和加上第i段的和,两段分别乘以各自的矫正因子即可得到正确结果,FlashAttention中就利用了公式(8)来计算safe softmax的分母。如果大家不理解公式(8)的正确性,可以考虑 L L L只有三段的情况,就很容易搞明白了。 L 0 = 0 , L 1 = ℓ 1 , L 2 = e M 1 − M 2 ∗ L 1 + e m 2 − M 2 ∗ ℓ 2 L_0=0,L_1=ℓ_1,L_{2}=e^{M_{1}-M_2}*L_{1}+e^{m_2-M_2}*ℓ_{2} L0=0,L1=1,L2=eM1M2L1+em2M22,表示 L 1 L_1 L1在计算过程中用的是第一段的最大值 M 1 M_1 M1,现在需要将其替换成 M 2 M_2 M2,所以我们乘以了 e M 1 − M 2 e^{M_{1}-M_2} eM1M2,而 ℓ 2 ℓ_{2} 2根据定义利用的是第二段的段内最大值计算的,现在也需要替换成 M 2 M_2 M2,所以我们乘以了 e m 2 − M 2 e^{m_2-M_2} em2M2。最终的表现就是 L 2 L_2 L2的计算从前两段求和来看每一项都是 e x − M 2 e^{x-M_2} exM2。类似的, L 3 = e M 2 − M 3 ∗ L 2 + e m 3 − M 3 ∗ ℓ 3 L_{3}=e^{M_{2}-M_3}*L_{2}+e^{m_3-M_3}*ℓ_{3} L3=eM2M3L2+em3M33,由于 L 2 L_2 L2每一项都是 e x − M 2 e^{x-M_2} exM2,所以我们将其乘以 e M 2 − M 3 e^{M_{2}-M_3} eM2M3把前两段都变成 e x − M 3 e^{x-M_3} exM3,第三段 ℓ 3 ℓ_{3} 3自身的矫正和第二段 ℓ 2 ℓ_{2} 2自身的矫正一样,如此我们就完成了 L 3 L_3 L3的计算。这里再强调一点, ℓ ℓ 的计算可以通过 M M M来实现,如此就可以避免对 ℓ ℓ 进行矫正,但是FlashAttention V1版本就是按照公式(8)实现的,所以我们也按照公式(8)来描述,在FlashAttention V2中对此进行了优化,可以加快推理速度。

上面的解释我们都是按照向量来的,在实际实现时向量会变成矩阵,每一行完成的操作都是一样的,分段softmax也就变成了分块softmax。

3.2 Attention的分块计算

有了分块softmax的经验,我们再看看如何分块计算Attention。基本上就是利用我们上面提到的 f f f Q 、 K 、 V Q、K、V QKV中的 V V V分块相乘,我们也只考虑 f f f V V V是向量的情况,矩阵只是多了很多次重复操作而已。
行向量 f f f与列向量 V V V相乘的结果是一个标量,我们记为 O O O。在分段情况下,如果不考虑正确性我们可以得到 O O O的计算公式为:
O i = O i − 1 + ∑ j = i ∗ B c ( i + 1 ) ∗ B c − 1 f i j ∗ V i j \begin{equation} O_{i}=O_{i-1}+\sum_{j=i*B_c}^{(i+1)*B_c-1}f_{ij}*V_{ij} \end{equation} Oi=Oi1+j=iBc(i+1)Bc1fijVij
公式(9)有两个错误:第一就是softmax计算公式中要除以 L L L,上面的公式没有考虑;第二就是公式的前半部分和后半部分都没有考虑矫正因子,我们分别修正这两个错误,获得正确的计算公式。首先考虑除以 L L L,可以得到如下公式:
O i = ( L i − 1 ∗ O i − 1 + f i ∗ V i ) / L i \begin{equation} O_{i}=(L_{i-1}*O_{i-1}+f_{i}*V_{i})/L_{i} \end{equation} Oi=(Li1Oi1+fiVi)/Li
如果我们能获得整个向量,则 f f f V V V相乘之后我们需要除以最终的 L L L才能获得正确的结果,当计算的第i段的时候也是一样,我们需要除以到第i段为止的 L L L值。同时,在迭代的过程中我们还要修正 O i − 1 O_{i-1} Oi1的值,因为 O i − 1 O_{i-1} Oi1除以了 L i − 1 L_{i-1} Li1,所以我们再乘以 L i − 1 L_{i-1} Li1即可。公式(10)可以看做对Attention计算过程中的分母进行了矫正,我们也要对分子进行矫正。分子的矫正类似于公式(8)计算 L L L,前i-1段每一项都是利用 e x − M i − 1 e^{x-M_{i-1}} exMi1进行计算,为了将其变为 e x − M i e^{x-M_{i}} exMi,我们需要乘以 e M i − 1 − M i e^{M_{i-1}-M_{i}} eMi1Mi,而第i段自身我们可以通过乘以 e m i − M i e^{m_{i}-M_{i}} emiMi进行矫正,如此得到最终计算Attention的公式:
O i = ( L i − 1 ∗ e M i − 1 − M i ∗ O i − 1 + e m i − M i ∗ f i ∗ V i ) / L i \begin{equation} O_{i}=(L_{i-1}*e^{M_{i-1}-M_{i}}*O_{i-1}+e^{m_{i}-M_{i}}*f_{i}*V_{i})/L_{i} \end{equation} Oi=(Li1eMi1MiOi1+emiMifiVi)/Li
有了这个公式你再去看FlashAttention的伪代码就没有任何难度了,此处按照向量来介绍,真实场景就是把 f f f V V V变成了块矩阵。

3.3 FlashAttention算法流程

介绍完分块softmax和单行Attention计算之后,我们就可以看一下FlashAttention算法的具体流程,如图1所示,该图截自FlashAttention的论文。首先强调一点,FlashAttention V1版本实现并不是最优的,初次看会觉得很多地方写得很别扭,也正是因为这个原因才会有FlashAttention V2和V3版本,我们只需要通过伪代码理解FlashAttention的原理即可。如果从CUDA编程角度来看,下面的伪代码实际上是单指一个block内的线程完成的计算,grid会从batch和attention head数层面对 Q 、 K 、 V Q、K、V QKV进行划分,也即伪代码实现了block内head=1的attention计算,但是block内每个线程如何划分数据进行attention计算伪代码中并未说明。
在这里插入图片描述

图1 FlashAttention算法伪代码

前面提到,FlashAttention是通过分块的方式优化attention计算的,所以代码的主体就是两层for循环遍历 Q 、 K 、 V Q、K、V QKV的不同块。外层for循环对 K 、 V K、V KV进行分块,每块大小 B c B_c Bc,总共有 T c T_c Tc块;内层循环对 Q Q Q和输出 O O O进行分块,每块大小 B r B_r Br,总共有 T r T_r Tr块。块大小的设计就是要保证 Q 、 K 、 V 、 O Q、K、V、O QKVO这四个变量的块都能放在共享内存里,伪代码第1行就是根据共享内存的大小来确定块大小。

最内层for循环执行核心的分块attention计算。伪代码第9行的 S S S保存了 Q Q Q K T K^T KT分块矩阵乘的结果,不知道大家看到这个地方的时候有没有疑问:伪代码中并未具体提到 S S S的访问方式,共享内存已被 Q 、 K 、 V 、 O Q、K、V、O QKVO这四个变量打满,所以我理解 S S S只能是放在寄存器中。 S S S的大小为 B r ∗ B c , B c B_r * B_c,B_c BrBcBc的大小没有限制,如果 B r B_r Br也不做限制,则在d较小的时候 S S S的大小就可能会非常大导致寄存器不够用,所以在伪代码中对 B r B_r Br进行了min的限制,这样 S S S的大小最大为 ⌈ M 4 ⌉ \lceil \frac {M} {4} \rceil 4M,而往往GPU SM中的寄存器大小要大于共享内存的大小(H 100的32位寄存器个数为64K,最大可设置共享内存为228K,可以参考Table 21 Technical Specifications per Compute Capability),所以就不会存在寄存器不够用的问题。虽然从寄存器规模上可以存储整个 S S S,但是在真实进行CUDA编程的时候,寄存器是属于每个线程独享的,所以 S S S的计算需要仔细设计,每个线程完成 S S S中部分数据的计算,这又会涉及如何对 S S S分块使得矩阵乘速度最快的问题,我们不关心如此细节,只关心FlashAttention整体的流程。还得说明一点,从论文中给出的加载流程(图二)可以看出, S S S的计算利用了SRAM,在非共享内存的情况下只可能是L1 Cache,但是具体怎么利用的我也没有搞清楚。
在这里插入图片描述

图2 FlashAttention分块加载流程
伪代码第10行用来计算三个变量: m ~ , P ~ , ℓ ~ \widetilde{m},\widetilde{P},\widetilde{ℓ} m ,P , ,分别对应上面提到的 m 、 f 和 ℓ m、f和ℓ mf,表示当前块每行的最大值、指数值、指数和。伪代码第11行用来求全局最大值 m n e w m^{new} mnew与全局求和 ℓ n e w ℓ^{new} new,分别对应上面提到的 M M M L L L。最难理解的是伪代码第12行,不过通过我们3.2节的分析理解第12行不再困难。从向量到矩阵还有一个变化,就是对分母的矫正利用了对角阵,这个在CUDA中有快速实现。

4. FlashAttention-2原理

在前面的介绍中我们提到了一些FlashAttention V1版本存在的问题,V2中对其进行了改进(论文为:FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning),我们着重介绍一下这些改进点。

4.1 改进1:循环次序的交换

一眼看去,图3和图1最大的区别就是两层for循环的次序发生了变化。V1版本中先遍历 K 、 V K、V KV再遍历 Q 、 O Q、O QO,在V2中先遍历 Q 、 O Q、O QO再遍历 K 、 V K、V KV。内层循环遍历 K 、 V K、V KV就可以一口气完成 O O O的计算,如此我们就可以避免 O O O与显存进行交换,进一步减少访存时间。
在这里插入图片描述

图3 FlashAttention2算法伪代码

4.2 改进2:迭代公式的优化

在3.1节我们提到,对公式(8)可以将当前段 ℓ ℓ 的矫正换成当前为止的最大值 M M M,从而避免乘以 e m i − M i e^{m_{i}-M_{i}} emiMi,如此可以将公式(8)优化为:
L i = e M i − 1 − M i ∗ L i − 1 + ℓ i \begin{equation} L_{i}=e^{M_{i-1}-M_i}*L_{i-1}+ℓ_{i} \end{equation} Li=eMi1MiLi1+i
同步地, f f f的定义也发生了变化, f f f表示每一段每一个元素利用全局最大值求 e x − M i e^{x-M_i} exMi,而不再是当前段的最大值 m i m_i mi。定义不变,还是表示每一段 f f f的累加值,不过由于 f f f利用了到当前段位置的全局最大值,所以 ℓ ℓ 计算的累加值也是正确的。
另一个优化点是公式(11)求O的过程中,我们每次都对分母进行矫正,但是从完整softmax的角度考虑,我们只需要在获得完整向量计算的 L L L之后再除以 L L L即可,不用迭代更新,如此就可以将公式(11)优化为:
O i = e M i − 1 − M i ∗ O i − 1 + f i ∗ V i \begin{equation} O_{i}=e^{M_{i-1}-M_{i}}*O_{i-1}+f_{i}*V_{i} \end{equation} Oi=eMi1MiOi1+fiVi
如伪代码中12行所示,内层循环结束之后, O O O再除以 L L L即是正确的Attention值。

4.3 改进3:对Q分块优化SM利用率

在3.3节一开始我们提到,V1对block的划分是基于batch_size和num_heads来做的,当上下文较短时,batch和heads个数会比较大,通常情况下会把GPU的SM打满。但是当上下文较长时,batch和heads就会较小,此时如果再按照batch和heads划分block就会存在SM空闲的问题,所以V2中又把 Q Q Q的分块作为grid的一个维度来划分block,这样就可以进一步提升grid的尺寸,从而可以使SM充分被利用起来。当然,在这种长上下文情况下,FlashAttention V2的伪代码也会发生变化,最外层的for循环就不存在了,只有内层for循环。相当于代码需要根据上下文长度进行区分,较短时就执行图3的伪代码,较长时就把外层循环去掉。

5.总结

24年7月份,FlashAttention-3也出来了,性能比V2又提升了1.5~2倍,主要是应用了现代GPU的特性,目前只支持H100,后续也会在更多卡上进行落地。我个人有一种感觉,V3应该不是终点,Attention计算还有可以挖掘的点,期待V4、V5的出现……

6.引用

  1. 图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑
  2. 图解大模型计算加速系列:Flash Attention V2,从原理到并行计算
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值