[NeurIPS 2022] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Introduction

  • 作者指出 attention 的性能瓶颈主要在访存而非计算量上,由此提出 FlashAttention,通过算子融合将 attention 操作融合为单个算子,大大降低了访存量,极大地提升了 Transformer 模型的训练速度
  • 目前 Huggingface Transformers 库已经支持了 FlashAttention

Method

Preliminaries - GPU Memory Hierarchy.
GPU 的存储器层次化结构可以大致分为 SRAM/HBM/DRAM,其中 HBM 就是 GPU 的显存,SRAM 是片上缓存;kernel 在运算时需要将输入从 HBM 中读入到 SRAM,计算完毕后再写回 SRAM. 对于 memory-bound 场景,如何降低 HBM 与 SRAM 之间的访存开销是关键
在这里插入图片描述

From Online Softmax to FlashAttention

  • self-attention 的计算过程如下所示,
    在这里插入图片描述在这里插入图片描述FlashAttention 的惊人之处在于计算时根本不需要在 global memory 里实例化 X X X A A A,而是直接把上述 3 个操作融进 1 个 CUDA kernel,根本不需要存储和加载 attention matrix (which can be very large when the sequence length or batchsize is large),从而减少了访存和存储开销
    在这里插入图片描述
  • 自注意力层算子融合的难点在于如何进行 tiling. 我们想要每次加载 Q , K , V , O Q,K,V,O Q,K,V,O 的一小块然后计算一部分的自注意力层输出,但由于 softmax 在计算时需要用到全部的 logits,因此简单的 tiling 策略无法得到正确的计算结果;对此,FlashAttention 使用了 online softmax 的思想

Preliminaries - Tiling.
由于 on-chip memory 空间有限,在计算矩阵乘 C = A B C=AB C=AB 时我们无法同时把 A , B , C A,B,C A,B,C 都加载进来,通常的做法是使用 tiling 每次加载 A , B , C A,B,C A,B,C 的一小块,每次都只计算一部分矩阵乘的结果;例如下图中将 A , B , C A,B,C A,B,C 都分成了 T × T T\times T T×T 大小的小块,每次就只需要将 3 个 T × T T\times T T×T 的小块从 global memory 加载进 on-chip memory
在这里插入图片描述

(Safe) Softmax

  • 如下所示,常规计算 softmax 需要 3 次循环;对于自注意力计算,由于 SRAM 没法一次性存下 pre-softmax logits X i = Q i K T X_i=Q_iK^T Xi=QiKT,因此就需要对 X i X_i Xi 访存 3 次 (not I/O efficient);如果要进一步降低访存次数,就需要减少循环次数

在这里插入图片描述

Online Softmax

  • online softmax 利用动态规划的思想,只需 2 次循环即可完成 softmax 的计算,具体来说,可以在循环时计算 d i ′ d_i' di,在循环结束时,有 d N = d N ′ d_N=d_N' dN=dN
    在这里插入图片描述这样我们就可以将 m i m_i mi d i ′ d_i' di 的计算融到 1 次循环里,即

在这里插入图片描述

FlashAttention

  • 将 online softmax 运用到 attention 计算上,可以得到需要 2 次循环的 attention 算法

在这里插入图片描述


  • 如果需要输出 softmax 的计算结果,那么 2 次循环已经是最优,但我们最终只需要输出 o N o_N oN 而不需要输出 a i a_i ai,因此可以利用 online softmax 的思想继续优化,将 2 次循环缩减为 1 次循环;首先可以写出如下递推式,满足 o N = o N ′ o_N=o_N' oN=oN
    在这里插入图片描述这样就可以在 1 个循环内得到 attention 的计算结果,即

在这里插入图片描述


  • tiling 版本的算法如下所示,把 sequence length N N N 划分为若干个大小为 b b b 的 tiles,每个 tile 顺序计算,对 SRAM 的需求量仅取决于 block size b b b (head dim 一般比较小,例如 LLaMA 系列为 128,因此无需在 head dim 上分块)

在这里插入图片描述

FlashAttention

  • FlashAttention 完整的算法流程如下,相比上述的 tiling 版本,会同时计算 O O O 的多行的结果 (9 ~ 12 行就是上述 tiling 算法的多行批处理版本,不仅对 K , V K,V K,V 做了分块,对 Q Q Q 也做了分块),每次输出 O i ∈ R B r × d O_{i}\in\R^{B_r\times d} OiRBr×d;block size B c , B r B_c,B_r Bc,Br 的选取应该是确保 tiling 所需存储空间小于 SRAM 大小,不过在 FlashAttention-2 中,block size 改为了从 { 64 , 128 } × { 64 , 128 } \{64,128\}\times\{64,128\} {64,128}×{64,128} 中手动选取最合适的
  • 相比标准的 attention 计算方法,FlashAttention 的计算复杂度不变,仍然为 O ( N 2 d ) O(N^2d) O(N2d) (9 行和 12 行的矩阵乘计算复杂度为 O ( B c B r d ) O(B_cB_rd) O(BcBrd),循环内需要执行 O ( T c T r ) = O ( N 2 B c B r ) O(T_cT_r)=O(\frac{N^2}{B_cB_r}) O(TcTr)=O(BcBrN2) 次);此外,FlashAttention 还需要额外 O ( N ) O(N) O(N) 的空间存储 l , m l,m l,m,相比之下,标准的 attetion 实现则需要额外 O ( N 2 ) O(N^2) O(N2) 的空间去存储 attetion matrix

在这里插入图片描述在这里插入图片描述

Backward – Recomputation.
为了避免 attention matrix O ( N 2 ) O(N^2) O(N2) 的存储开销,作者在前向传播时并不会存储 X , A X,A X,A,而是只保存 attetion 输出 O O O 以及 softmax normalization statistics ( m , ℓ m,\ell m,),从而在反向传播时,将 Q , K , V Q,K,V Q,K,V 加载进 SRAM 后重计算得到 X , A X,A X,A 以进行反向传播;即使增加了计算量,但由于访存量的降低,FlashAttetion 的训练速度仍然远快于标准的 attention 实现 (反向传播过程详见论文 Appendix B)

Analysis: IO Complexity of FlashAttention

  • K , V K,V K,V 的每个 block 都被加载进内存 1 次,访存次数为 O ( N d ) O(Nd) O(Nd) (6 行); Q , O Q,O Q,O 的每个 block 都被读/写 T c T_c Tc 次,访存次数为 O ( T c N d ) = O ( N 2 d 2 M ) O(T_cNd)=O(\frac{N^2d^2}{M}) O(TcNd)=O(MN2d2);所以总的访存次数 O ( N 2 d 2 M ) O(\frac{N^2d^2}{M}) O(MN2d2);相比之下,标准的 attention 实现访存次数为 O ( N d + N 2 ) O(Nd+N^2) O(Nd+N2) (读/写 Q , K , V Q,K,V Q,K,V 和 attetion matrix)
    在这里插入图片描述
  • For typical values of d d d (64-128) and M M M (around 100KB), d 2 M ≪ 1 \frac{d^2}{M}\ll1 Md21. 因此 FlashAttention 的访存次数远小于标准的 attetion 实现,并且作者还证明了 FlashAttention 的访存复杂度是 exact attention algorithm 中最优的

Experiments

  • Training Speed.
    在这里插入图片描述在这里插入图片描述在这里插入图片描述
  • Benchmarking Attention.
    在这里插入图片描述

References

  • 24
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值