FlashAttention-2 是如何实现更快的计算速度的

FlashAttention-2是对原始FlashAttention算法的一系列改进,旨在优化在GPU上的计算性能。本节详细讨论了FlashAttention-2的算法、并行性以及工作分区策略。

算法

FlashAttention-2的关键优化点在于减少非矩阵乘法(matmul)的浮点运算,以充分利用GPU上的专用计算单元(如Nvidia GPU上的Tensor Cores),这些单元在处理matmul操作(尤其是在FP16/BF16格式下)时性能显著优化。该优化的目标是通过尽可能多地执行matmul操作来最大化GPU的吞吐量。

前向传播
  1. 在线Softmax技巧:FlashAttention-2对在线Softmax计算进行了修改,以最小化非matmul浮点操作:

    • 避免通过 diag(ℓ(2))^-1 重新缩放输出更新的两个项。
    • 维持一个“未缩放”的O(2)版本,并保留统计信息 ℓ(2)。
    • 仅在循环结束时,通过 diag(ℓ(last))^-1 缩放最终的O(last)以获得正确的输出。
  2. 最大化matmul FLOPs:为了最大化GPU的性能,FlashAttention-2重点优化了matmul操作,因为现代GPU上的专用单元(如Tensor Cores)在这些操作上表现出色。以Nvidia A100 GPU为例,其FP16/BF16 matmul的理论吞吐量可以达到312 TFLOPs/s,而非matmul FP32的吞吐量仅为19.5 TFLOPs/s。因此,FlashAttention-2通过优化算法,尽可能地减少非matmul操作,从而保持高吞吐量的执行效率。

  3. 算法细节:FlashAttention-2的前向传播通过以下步骤实现:

    • 将输入矩阵Q、K、V分成大小为𝐵𝑟 × 𝑑的𝑇𝑟块,将输出矩阵O和logsumexp𝐿也相应地分块。
    • 在每个线程块内部分配工作以最大化GPU资源的利用。
    • 引入了在线Softmax技巧,通过有效管理和缩放中间结果,减少了不必要的计算开销。

反向传播

FlashAttention-2的反向传播与FlashAttention类似,但也有一些微调:

  • 仅使用逐行logsumexp 𝐿,而不是softmax中的最大值和指数和。
  • 使用类似的分块策略来优化计算和内存访问,以提高反向传播的效率和性能。

FlashAttention-2在并行性和工作分区方面进行了深入优化,以在GPU上实现更高的计算效率和性能。本节详细讨论了FlashAttention-2的并行化策略和工作分区方法。

并行性

前向传播

在FlashAttention-2中,前向传播的并行化策略如下:

  1. 线程块调度:每个注意力头使用一个线程块来处理,总共有batch size × number of heads个线程块。每个线程块被调度到一个流多处理器(SM)上执行。例如,Nvidia A100 GPU上有108个这样的SM。这种调度在大量线程块(如≥ 80)时非常高效,因为可以充分利用GPU的计算资源。

  2. 对长序列的优化:对于长序列(通常意味着较小的batch size或较少的头数),为了更好地利用GPU上的多处理器,FlashAttention-2额外并行化了序列长度维度。这在这种情况下显著提高了性能和效率。

反向传播

在反向传播中,为了避免在不同列块之间的共享计算,FlashAttention-2采用了类似的并行化策略:

  • 线程块调度:每个列块使用一个线程块来处理。通过使用原子加操作来在不同线程块之间进行通信,以更新dQ,从而避免了共享内存的读写冲突。

工作分区

前向传播

在前向传播中,FlashAttention-2改进了工作分区策略,避免了FlashAttention中的"split-K"方案,具体包括:

  • K和V的分割:FlashAttention-2将Q分割到4个线程束(warp)中,同时使得K和V对所有线程束可访问。每个线程束执行矩阵乘法以获取QK>的一部分,并将其与V的一部分相乘,从而获得对应输出的片段。这种改进减少了线程束之间的通信,降低了共享内存的读写次数,从而提升了性能。
反向传播

在反向传播中,为了避免"split-K"方案带来的同步问题,FlashAttention-2选择了适当的线程束分区策略,以优化计算和内存访问效率。

3.1.1 前向传播公式解析

在前向传播中,我们使用在线Softmax技巧,并进行了两处微调来减少非矩阵乘法的浮点运算。

1. 不对中间结果进行缩放
原始的公式是:
O ( 2 ) = diag ( ℓ ( 1 ) ℓ ( 2 ) ) − 1 O ( 1 ) + diag ( ℓ ( 2 ) ) − 1 e S ( 2 ) − m ( 2 ) V ( 2 ) O^{(2)} = \text{diag}\left(\frac{\ell^{(1)}}{\ell^{(2)}}\right)^{-1} O^{(1)} + \text{diag}(\ell^{(2)})^{-1} e^{S^{(2)} - m^{(2)}} V^{(2)} O(2)=diag((2)(1))1O(1)+diag((2))1eS(2)m(2)V(2)

这里, ( O ( 2 ) ) (O^{(2)}) (O(2)) 表示第二块的输出, ( ℓ ( 1 ) ) (\ell^{(1)}) ((1)) ( ℓ ( 2 ) ) (\ell^{(2)}) ((2))分别是第一块和第二块的行和。我们先对第一项 ( O ( 1 ) ) (O^{(1)}) (O(1)) 进行缩放,再加上经过软max计算的第二项。改进后,简化为:
O ~ ( 2 ) = diag ( ℓ ( 1 ) ) − 1 O ( 1 ) + e S ( 2 ) − m ( 2 ) V ( 2 ) \tilde{O}^{(2)} = \text{diag}(\ell^{(1)})^{-1} O^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)} O~(2)=diag((1))1O(1)+eS(2)m(2)V(2)

我们不再缩放第二项,只保留未缩放的中间结果 ( O ~ ( 2 ) ) (\tilde{O}^{(2)}) (O~(2))。最后一步,对最终结果进行缩放:
O ( 2 ) = diag ( ℓ ( last ) ) − 1 O ~ ( last ) O^{(2)} = \text{diag}(\ell^{(\text{last})})^{-1} \tilde{O}^{(\text{last})} O(2)=diag((last))1O~(last)

2. 使用logsumexp简化存储
我们只需存储logsumexp:
L ( j ) = m ( j ) + log ⁡ ( ℓ ( j ) ) L^{(j)} = m^{(j)} + \log(\ell^{(j)}) L(j)=m(j)+log((j))

这样做简化了存储,减少了计算中需要保存的变量数量。

具体计算过程
假设我们有两个块:

  1. 第一个块:
    m ( 1 ) = rowmax ( S ( 1 ) ) m^{(1)} = \text{rowmax}(S^{(1)}) m(1)=rowmax(S(1))
    计算第一个块的行最大值,得到 ( m ( 1 ) ) (m^{(1)}) (m(1))

ℓ ( 1 ) = rowsum ( e S ( 1 ) − m ( 1 ) ) \ell^{(1)} = \text{rowsum}(e^{S^{(1)} - m^{(1)}}) (1)=rowsum(eS(1)m(1))
计算第一个块的行和,得到 ( ℓ ( 1 ) ) (\ell^{(1)}) ((1))

O ~ ( 1 ) = e S ( 1 ) − m ( 1 ) V ( 1 ) \tilde{O}^{(1)} = e^{S^{(1)} - m^{(1)}} V^{(1)} O~(1)=eS(1)m(1)V(1)
计算未缩放的第一个块的输出。

  1. 第二个块:
    m ( 2 ) = max ⁡ ( m ( 1 ) , rowmax ( S ( 2 ) ) ) = m m^{(2)} = \max(m^{(1)}, \text{rowmax}(S^{(2)})) = m m(2)=max(m(1),rowmax(S(2)))=m
    计算第二个块的行最大值,并与第一个块的最大值进行比较,得到新的最大值 ( m ) (m) (m)

ℓ ( 2 ) = e m ( 1 ) − m ( 2 ) ℓ ( 1 ) + rowsum ( e S ( 2 ) − m ( 2 ) ) = rowsum ( e S ( 1 ) − m ) + rowsum ( e S ( 2 ) − m ) \ell^{(2)} = e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \text{rowsum}(e^{S^{(2)} - m^{(2)}}) = \text{rowsum}(e^{S^{(1)} - m}) + \text{rowsum}(e^{S^{(2)} - m}) (2)=em(1)m(2)(1)+rowsum(eS(2)m(2))=rowsum(eS(1)m)+rowsum(eS(2)m)
计算第二个块的行和,同时考虑第一个块的贡献。

P ~ ( 2 ) = diag ( ℓ ( 2 ) ) − 1 e S ( 2 ) − m ( 2 ) \tilde{P}^{(2)} = \text{diag}(\ell^{(2)})^{-1} e^{S^{(2)} - m^{(2)}} P~(2)=diag((2))1eS(2)m(2)
计算缩放后的软max概率。

O ~ ( 2 ) = diag ( e m ( 1 ) − m ( 2 ) ) O ~ ( 1 ) + e S ( 2 ) − m ( 2 ) V ( 2 ) = e S ( 1 ) − m V ( 1 ) + e S ( 2 ) − m V ( 2 ) \tilde{O}^{(2)} = \text{diag}(e^{m^{(1)} - m^{(2)}}) \tilde{O}^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)} = e^{S^{(1)} - m} V^{(1)} + e^{S^{(2)} - m} V^{(2)} O~(2)=diag(em(1)m(2))O~(1)+eS(2)m(2)V(2)=eS(1)mV(1)+eS(2)mV(2)
结合前两个块的结果,得到未缩放的第二个块的输出。

O ( 2 ) = diag ( ℓ ( 2 ) ) − 1 O ~ ( 2 ) = O O^{(2)} = \text{diag}(\ell^{(2)})^{-1} \tilde{O}^{(2)} = O O(2)=diag((2))1O~(2)=O
最终对未缩放的输出进行缩放,得到最终结果 ( O ) (O) (O)

3.1.2 反向传播公式解析

在反向传播中,我们也进行了简化,只使用logsumexp ( L ) ( L ) (L)

反向传播的具体计算步骤

  1. 初始化:
    d Q = ( 0 ) N × d dQ = (0)_{N \times d} dQ=(0)N×d
    初始化 dQ 为零矩阵。

  2. 计算梯度:
    D = rowsum ( d O ∘ O ) D = \text{rowsum}(dO \circ O) D=rowsum(dOO)
    计算 dO 和 O 的点积的行和。

每个块的处理:
S ( j ) = Q i K j T S^{(j)} = Q_i K_j^T S(j)=QiKjT
计算块 ( i ) (i) (i) ( j ) (j) (j) 的乘积 ( S ( j ) ) (S^{(j)}) (S(j))

P ( j ) = e S ( j ) − L i P^{(j)} = e^{S^{(j)} - L_i} P(j)=eS(j)Li
计算 ( P ( j ) ) (P^{(j)}) (P(j)),即 ( S ( j ) ) (S^{(j)}) (S(j)) 减去 ( L i ) (L_i) (Li) 后的指数。

d V j ← d V j + ( P ( j ) ) T d O i dV_j \leftarrow dV_j + (P^{(j)})^T dO_i dVjdVj+(P(j))TdOi
更新 dV_j。

d P ( j ) = d O i V j T dP^{(j)} = dO_i V_j^T dP(j)=dOiVjT
计算 dP^{(j)}。

d S ( j ) = P ( j ) ∘ ( d P ( j ) − D i ) dS^{(j)} = P^{(j)} \circ (dP^{(j)} - D_i) dS(j)=P(j)(dP(j)Di)
计算 dS^{(j)}。

d Q i ← d Q i + d S ( j ) K j dQ_i \leftarrow dQ_i + dS^{(j)} K_j dQidQi+dS(j)Kj
更新 dQ_i。

d K j ← d K j + ( d S ( j ) ) T Q i dK_j \leftarrow dK_j + (dS^{(j)})^T Q_i dKjdKj+(dS(j))TQi
更新 dK_j。

通过这些步骤,我们完成了反向传播的计算,得到dQ、dK和dV。

多查询注意力和分组查询注意力

在这些变体中,我们将多个头的查询集中到同一个头的键和值上,以减少KV缓存的大小。在反向传播中,我们需要跨不同头对梯度dK和dV求和。

  • 8
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值