FlashAttention
Attention 的计算复杂度
Attention 的计算过程
- 线性变换:对输入序列 X ∈ R N × k X\in \mathbb{R}^{N\times k} X∈RN×k 进行线性变换,分别乘以三个不同的权重矩阵 W Q W^Q WQ, W K W^K WK, W V W^V WV ∈ R k × d \in \mathbb{R}^{k\times d} ∈Rk×d 得到 Q Q Q、 K K K、 V V V ∈ R N × d \in \mathbb{R}^{N\times d} ∈RN×d。,则该步骤的复杂度为 O ( 3 N ∗ k ∗ d ) O(3N*k*d) O(3N∗k∗d)。
- 计算相似度得分:通过 Q Q Q、 K K K 两个矩阵计算相似度得分,得到注意力权重矩阵 ∈ R N × N \in \mathbb{R}^{N\times N} ∈RN×N,时间复杂度为 O ( N ∗ N ∗ d ) O(N* N*d) O(N∗N∗d)。
- softmax: 时间复杂度为: O ( N ∗ N ) O(N* N) O(N∗N)
- 加权平均:将 softmax 之后的注意力权重矩阵与 V V V 矩阵相乘。该步骤的复杂度为 O ( N ∗ N ∗ d ) O(N*N * d) O(N∗N∗d)。
- 总体的计算复杂度约为 O ( N ∗ N ∗ d ) O(N* N*d) O(N∗N∗d)
Structure of GPU Memory
在GPU 当中,memory 也跟CPU memory 一样分成不同的level,通常越上层空间越小但是速度越快,
大家平常主要提到的GPU memory 通常是指high bandwidth memory (HBM),以A100 来说, HBM 大概有40 GB ~ 80 GB 左右,且HBM 的bandwidth 为1.5–2.0TB/s。
再往上一层的memory 称为SRAM,总容量大概有192 KB * 108 (streaming multiproecssors) 左右,虽然大小少很多,但是他的bandwidth 可以达到19TB/s,因此当有运算需要从HBM 当中不断读写资料的时候,这样的速度差就容易导致HBM 的读取变成整体效能的bottleneck。
在GPU 当中有非常大量的threads (kernel) 负责执行operation 的运算,而整个运算的过程基本上是从HBM 当中将资料载入至SRAM 中,执行运算并将output 存回HBM 当中。
根据每个operation 实际运算时间和memory 存取的时间多寡,我们可以将operations 归纳为两个类别,分别是compute-bound以及memory-bound
Compute-bound的意思为运算的主要时间都耗费在operation 的计算上,HBM 的存取只占了其中一点点的时间,例如多维度的矩阵相乘或是高channel 数的convolution都属于这类。
Memory-bound的意思为运算主要时间都耗费在memory 的读取上,而实际的运算只占了其中一点点的时间,例如elementwise (eg, activation, dropout) 和reduction (eg, sum, softmax, batch norm, layer norm) 皆属于memory-bound。
Tiling and Recomputation
FlashAttention 就是通过利用 GPU 硬件中的特殊设计,针对全局内存和共享存储的 I/O 速度的不同,尽可能
- 高效地使用 SRAM 来加快计算速度
- 避免 HBM 中读取或写入注意力矩阵
达成该目标需要能 做到在不访问整个输入的情况下计算 Softmax 函数,并且后向传播中不能存储中间注意力矩阵。
标准 Attention 算法中,Softmax 计算按行进行,即在与 $V $ 做矩阵乘法之前,需要将 Q Q Q、 K K K 的各个分块完成一整行的计算。在得到 Softmax 的结果后,再与矩阵 V V V 分块做矩阵乘。
而在 FlashAttention 中,将输入分割成块,并在输入块上进行多次传递,从而 以增量方式 执行 Softmax 计算。
- Tiling (在向前和向后传递时使用)-基本上将 N × N N \times N N×N softmax/scores 矩阵分块成块。
- Recomputation (重算,仅在向后传递中使用)
Tiling(平铺),其核心思想是将原始的注意力矩阵分解成更小的子矩阵,然后分别对这些子矩阵进行计算,只要这个子矩阵的大小可以在 SRAM 内存放,那么即可在计算过程中只访问 SRAM 了。
Theorem 1
Flash Attention 的算法能够在 O ( N 2 d ) O(N^2d) O(N2d) FLOPs 内正确的返回 O = s o f t m a x ( Q K T ) V O=\mathrm{softmax}(QK^T)V O=softmax(QKT)V,除去输入和输出之外还需要 O ( N ) O(N) O(N) 大小的空间。
输入: Q , K , V Q,K,V Q,K,V $\in \mathbb{R}^{N\times d} $ 位于高速显存(HBM)中,GPU 芯片中的 SRAM 大小 为 M M M
输出: O O O
-
设置块大小 block size
-
block 的列大小: B c = [ M 4 d ] B_c =[\frac{M}{4d}] Bc=[4dM]
-
block 的行大小: B r = min ( [ M 4 d ] , d ) B_r = \min \big([\frac{M}{4d}], d \big) Br=min([4dM],d)
min \min min 函数的目的是防止块大小 B c × B r > M 4 B_c \times B_r > \frac{M}{4} Bc×Br>4M, 这样就无法把 4 个这样的块放到 SRAM。
-
-
HBM 中初始化
- O = ( 0 ) N × d ∈ R N × d O = (0)_{N \times d} \in \mathbb{R}^{N\times d} O=(0)N×d∈RN×d:结果矩阵 O O O 初始化为零,后面会逐步把 中间结果 累加进去
- l = ( 0 ) N ∈ R N l = (0)_{N}\in \mathbb{R}^{N} l=(0)N∈RN :对于每一行来说,它是一个标量,用于 累加指数和,由于输出有 N N N 行,所以这里的 l l l 是长度为 N N N 的向量
- m = ( − ∞ ) N ∈ R N m= (- \infty)_N \in \mathbb{R}^{N} m=(−∞)N∈RN: m m m 用于记录 每一行当前最大的值,所以也是长度为 N N N,而 − ∞ -\infty −∞ 是求 max \max max 的合适初始值。
-
切分子块
-
将矩阵 Q Q Q 按行切分成 T r = [ N B r ] T_r = [\frac{N}{B_r}] Tr=[BrN] 块 Q 1 , … , Q T r Q_1, \dots , Q_{T_r} Q1,…,QTr, Q i ∈ R B r × d Q_i\in \mathbb{R}^{B_r\times d} Qi∈RBr×d
Q i = ( Q i 1 1 Q i 1 2 … Q i 1 d Q i 2 1 Q i 2 2 … Q i 2 d ⋮ ⋮ ⋱ ⋮ Q i B r 1 Q i B r 2 … Q i B r d ) Q_i =\begin{pmatrix} Q_{i_11} & Q_{i_12} & \ldots & Q_{i_1d} \\ Q_{i_21} & Q_{i_22} & \ldots & Q_{i_2d}\\ \vdots & \vdots & \ddots &\vdots\\ Q_{i_{B_r}1} & Q_{i_{B_r}2} & \ldots & Q_{i_{B_r}d}\\ \end{pmatrix} Qi= Qi11Qi21⋮QiBr1Qi12Qi22⋮QiBr2……⋱…Qi1dQi2d⋮QiBrd - 将矩阵 K K K 按行切分成 T c = [ N B c ] T_c = [\frac{N}{B_c}] Tc=[BcN] 块 K 1 , … , K T c K_1, \dots , K_{T_c} K1,…,KTc,
-