FlashAttention v2相比于v1有哪些更新?

1. FlashAttention v1

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

FlashAttention不需要在全局内存上实现 X 和 A 矩阵,而是将上述公式中的整个计算融合到单个 CUDA 内核中。这要求我们设计一种算法来仔细管理片上内存(on-chip memory)(如流算法),因为 NVIDIA GPU 的共享内存(SRAM)很小。

对于矩阵乘法等经典算法,使用平铺(tiling)来确保片上内存不超过硬件限制。这种平铺方法是有效的原因是:加法是关联的,允许将整个矩阵乘法分解为许多平铺矩阵乘法的总和。

请添加图片描述

2. FlashAttention v2的更新

请添加图片描述

  • 减少了non-matmul FLOPs的数量(消除了原先频繁rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的执行时间较长,这是因为GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽可能多地执行matmul FLOPs非常重要。

  • 提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。即使对于单个head,也在不同的thread block之间进行并行计算。

  • 在一个attention计算块内,将工作分配在一个thread block的不同warp上,以减少通信和共享内存读/写。

3. 第一个更新详解

减少了non-matmul FLOPs的数量(消除了原先频繁rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的执行时间较长,这是因为GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽可能多地执行matmul FLOPs非常重要。

non-matmul vs. GEMM

non matrix multiply

非矩阵乘法:指的是在矩阵乘法之外的操作,如加法、乘法、除法等, 例如rescale操作,使用的是GPU的通用硬件:CUDA Core.

General Matrix Multiply

通用矩阵乘法:指的是矩阵乘法的一种实现方式,使用的是GPU的专用硬件:Tensor Core.

CUDA Core vs. Tensor Core

非矩阵乘法操作的执行时间较长,因为GPU有专用的矩阵乘法计算单元Tensor Core,其吞吐量高达非矩阵乘法吞吐量的16倍。
下图是A100的介绍,可以看到CUDA Core单精度TF16的计算能力是19.5TFLOPS,而Tensor Core的计算能力是312TFLOPS,相差16倍之多。
请添加图片描述

Rescale操作

请添加图片描述

V2版本调整了算法,减少了非矩阵乘法操作的浮点运算次数,同时保持输出不变。在原始的FlashAttention(即V1版本)中,每个block的每次迭代都需要执行rescale操作,这涉及到除法运算。

而在V2中,这种rescale操作被延后到循环的最后才执行一次,每次计算可以减少一次除法运算。这样的调整是因为只要每次迭代确保分子部分被scale为正确值以及分母部分计算正确即可。这种优化减少了计算量,提高了效率

4. 第二个更新详解

提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。即使对于单个head,也在不同的thread block之间进行并行计算。

GPU硬件

  • Streaming Processor(SP):是最基本的处理单元,从fermi架构开始被叫做CUDA core。
  • Streaming MultiProcessor(SM):一个SM由多个CUDA core(SP)组成,每个SM在不同GPU架构上有不同数量的CUDA core,例如Pascal架构中一个SM有128个CUDA core。

GPU软件

  • thread:一个CUDA并行程序由多个thread来执行,thread是最基本的执行单元(the basic unit of execution);
  • warp:一个warp通常包含32个thread。每个warp中的thread可以同时执行相同的指令,从而实现SIMT(单指令多线程)并行。warp是SM中最小的调度单位(the smallest scheduling unit on an SM),一个SM可以同时处理多个warp;
  • thread block:一个thread block可以包含多个warp,同一个block中的thread可以同步,也可以通过shared memory进行通信。thread block是GPU执行的最小单位(the smallest unit of execution on the GPU)。一个warp中的threads必然在同一个block中,如果block所含thread数量不是warp大小的整数倍,那么多出的那个warp中会剩余一些inactive的thread。也就是说,即使warp的thread数量不足,硬件也会为warp凑足thread,只不过这些thread是inactive状态,但也会消耗SM资源。

V2 相对于 V1 的第二个主要更新是增加了序列长度维度的并行计算,这样做的目的是提高 GPU 的 SM(Streaming Multiprocessor)利用率,尤其是在处理长序列数据时。在 V1 中,计算是先按批次和头数并行执行,然后在序列长度上串行执行。这意味着当序列长度较长时,可能无法充分利用所有可用的 SM,因为每个 block 只能处理序列的一个片段。

在 V2 中,通过在序列长度维度上进行并行化,可以更有效地分配计算任务到更多的 block,从而更充分地利用 GPU 资源。具体来说,V2 通过增加 num_m_block 的概念,将 Q 矩阵在序列长度方向上进一步划分为多个小块,每一块由不同的 block 来处理。这样,每个 block 可以独立地计算它所负责的输出部分,减少了不同 block 之间的依赖和通信开销。

这边有点像continous batching的思路

5. 第三个更新详解

请添加图片描述

在V2中,通过调整循环顺序,将Q作为外循环,K和V作为内循环,每个线程块(thread block)负责计算输出矩阵O的一部分。这种设计允许每个线程块独立进行计算,减少了线程块之间的依赖和通信需求。同时,V2版本在前向传播中进一步减少了非矩阵乘法操作的浮点运算,以充分利用GPU上的专用计算单元,如Nvidia GPU上的Tensor Cores,从而最大化GPU的吞吐量。
此外,V2版本在反向传播中也进行了优化,采用了类似的分块策略来优化计算和内存访问,提高效率和性能。通过这种方式,FlashAttention V2能够实现更高的并行性,减少不必要的计算和内存访问,从而提升整体的计算性能。

参考

[1] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

[2] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

[3] FlashAttention2详解(性能比FlashAttention提升200%

[4] GitHub: LLMForEverybody

请添加图片描述

  • 23
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值