FlashAttention

FlashAttention

Attention 的计算复杂度

Attention 的计算过程

  • 线性变换:对输入序列 X ∈ R N × k X\in \mathbb{R}^{N\times k} XRN×k 进行线性变换,分别乘以三个不同的权重矩阵 W Q W^Q WQ W K W^K WK W V W^V WV ∈ R k × d \in \mathbb{R}^{k\times d} Rk×d 得到 Q Q Q K K K V V V ∈ R N × d \in \mathbb{R}^{N\times d} RN×d。,则该步骤的复杂度为 O ( 3 N ∗ k ∗ d ) O(3N*k*d) O(3Nkd)
  • 计算相似度得分:通过 Q Q Q K K K 两个矩阵计算相似度得分,得到注意力权重矩阵 ∈ R N × N \in \mathbb{R}^{N\times N} RN×N,时间复杂度为 O ( N ∗ N ∗ d ) O(N* N*d) O(NNd)
  • softmax: 时间复杂度为: O ( N ∗ N ) O(N* N) O(NN)
  • 加权平均:将 softmax 之后的注意力权重矩阵与 V V V 矩阵相乘。该步骤的复杂度为 O ( N ∗ N ∗ d ) O(N*N * d) O(NNd)
  • 总体的计算复杂度约为 O ( N ∗ N ∗ d ) O(N* N*d) O(NNd)

Structure of GPU Memory

image-20240709170708194

在GPU 当中,memory 也跟CPU memory 一样分成不同的level,通常越上层空间越小但是速度越快

大家平常主要提到的GPU memory 通常是指high bandwidth memory (HBM),以A100 来说, HBM 大概有40 GB ~ 80 GB 左右,且HBM 的bandwidth 为1.5–2.0TB/s。

再往上一层的memory 称为SRAM,总容量大概有192 KB * 108 (streaming multiproecssors) 左右,虽然大小少很多,但是他的bandwidth 可以达到19TB/s,因此当有运算需要从HBM 当中不断读写资料的时候,这样的速度差就容易导致HBM 的读取变成整体效能的bottleneck。

在GPU 当中有非常大量的threads (kernel) 负责执行operation 的运算,而整个运算的过程基本上是从HBM 当中将资料载入至SRAM 中,执行运算并将output 存回HBM 当中。

根据每个operation 实际运算时间和memory 存取的时间多寡,我们可以将operations 归纳为两个类别,分别是compute-bound以及memory-bound

  • Compute-bound的意思为运算的主要时间都耗费在operation 的计算上,HBM 的存取只占了其中一点点的时间,例如多维度的矩阵相乘或是高channel 数的convolution都属于这类。

  • Memory-bound的意思为运算主要时间都耗费在memory 的读取上,而实际的运算只占了其中一点点的时间,例如elementwise (eg, activation, dropout) 和reduction (eg, sum, softmax, batch norm, layer norm) 皆属于memory-bound。


Tiling and Recomputation

FlashAttention 就是通过利用 GPU 硬件中的特殊设计,针对全局内存和共享存储的 I/O 速度的不同,尽可能

  1. 高效地使用 SRAM 来加快计算速度
  2. 避免 HBM 中读取或写入注意力矩阵

达成该目标需要能 做到在不访问整个输入的情况下计算 Softmax 函数,并且后向传播中不能存储中间注意力矩阵

标准 Attention 算法中,Softmax 计算按行进行,即在与 $V $ 做矩阵乘法之前,需要将 Q Q Q K K K 的各个分块完成一整行的计算。在得到 Softmax 的结果后,再与矩阵 V V V 分块做矩阵乘。

而在 FlashAttention 中,将输入分割成块,并在输入块上进行多次传递,从而 以增量方式 执行 Softmax 计算。

  1. Tiling (在向前和向后传递时使用)-基本上将 N × N N \times N N×N softmax/scores 矩阵分块成块。
  2. Recomputation (重算,仅在向后传递中使用)

Tiling(平铺),其核心思想是将原始的注意力矩阵分解成更小的子矩阵,然后分别对这些子矩阵进行计算,只要这个子矩阵的大小可以在 SRAM 内存放,那么即可在计算过程中只访问 SRAM 了。

Theorem 1

Flash Attention 的算法能够在 O ( N 2 d ) O(N^2d) O(N2d) FLOPs 内正确的返回 O = s o f t m a x ( Q K T ) V O=\mathrm{softmax}(QK^T)V O=softmax(QKT)V,除去输入和输出之外还需要 O ( N ) O(N) O(N) 大小的空间。

  • 输入: Q , K , V Q,K,V Q,K,V $\in \mathbb{R}^{N\times d} $ 位于高速显存(HBM)中,GPU 芯片中的 SRAM 大小 M M M

  • 输出: O O O

  1. 设置块大小 block size

    • block 的列大小: B c = [ M 4 d ] B_c =[\frac{M}{4d}] Bc=[4dM]

    • block 的行大小: B r = min ⁡ ( [ M 4 d ] , d ) B_r = \min \big([\frac{M}{4d}], d \big) Br=min([4dM],d)

      min ⁡ \min min 函数的目的是防止块大小 B c × B r > M 4 B_c \times B_r > \frac{M}{4} Bc×Br>4M, 这样就无法把 4 个这样的块放到 SRAM。

  2. HBM 中初始化

    • O = ( 0 ) N × d ∈ R N × d O = (0)_{N \times d} \in \mathbb{R}^{N\times d} O=(0)N×dRN×d:结果矩阵 O O O 初始化为零,后面会逐步把 中间结果 累加进去
    • l = ( 0 ) N ∈ R N l = (0)_{N}\in \mathbb{R}^{N} l=(0)NRN :对于每一行来说,它是一个标量,用于 累加指数和,由于输出有 N N N 行,所以这里的 l l l 是长度为 N N N 的向量
    • m = ( − ∞ ) N ∈ R N m= (- \infty)_N \in \mathbb{R}^{N} m=()NRN m m m 用于记录 每一行当前最大的值,所以也是长度为 N N N,而 − ∞ -\infty 是求 max ⁡ \max max 的合适初始值。
  3. 切分子块

    • 将矩阵 Q Q Q 按行切分成 T r = [ N B r ] T_r = [\frac{N}{B_r}] Tr=[BrN] Q 1 , … , Q T r Q_1, \dots , Q_{T_r} Q1,,QTr Q i ∈ R B r × d Q_i\in \mathbb{R}^{B_r\times d} QiRBr×d
      Q i = ( Q i 1 1 Q i 1 2 … Q i 1 d Q i 2 1 Q i 2 2 … Q i 2 d ⋮ ⋮ ⋱ ⋮ Q i B r 1 Q i B r 2 … Q i B r d ) Q_i =\begin{pmatrix} Q_{i_11} & Q_{i_12} & \ldots & Q_{i_1d} \\ Q_{i_21} & Q_{i_22} & \ldots & Q_{i_2d}\\ \vdots & \vdots & \ddots &\vdots\\ Q_{i_{B_r}1} & Q_{i_{B_r}2} & \ldots & Q_{i_{B_r}d}\\ \end{pmatrix} Qi= Qi11Qi21QiBr1Qi12Qi22QiBr2Qi1dQi2dQiBrd

      • 将矩阵 K K K 按行切分成 T c = [ N B c ] T_c = [\frac{N}{B_c}] Tc=[BcN] K 1 , … , K T c K_1, \dots , K_{T_c} K1,,KTc
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值