FlashAttention-2是对原始FlashAttention算法的一系列改进,旨在优化在GPU上的计算性能。本节详细讨论了FlashAttention-2的算法、并行性以及工作分区策略。
算法
FlashAttention-2的关键优化点在于减少非矩阵乘法(matmul)的浮点运算,以充分利用GPU上的专用计算单元(如Nvidia GPU上的Tensor Cores),这些单元在处理matmul操作(尤其是在FP16/BF16格式下)时性能显著优化。该优化的目标是通过尽可能多地执行matmul操作来最大化GPU的吞吐量。
前向传播
-
在线Softmax技巧:FlashAttention-2对在线Softmax计算进行了修改,以最小化非matmul浮点操作:
- 避免通过
diag(ℓ(2))^-1
重新缩放输出更新的两个项。 - 维持一个“未缩放”的O(2)版本,并保留统计信息 ℓ(2)。
- 仅在循环结束时,通过
diag(ℓ(last))^-1
缩放最终的O(last)以获得正确的输出。
- 避免通过
-
最大化matmul FLOPs:为了最大化GPU的性能,FlashAttention-2重点优化了matmul操作,因为现代GPU上的专用单元(如Tensor Cores)在这些操作上表现出色。以Nvidia A100 GPU为例,其FP16/BF16 matmul的理论吞吐量可以达到312 TFLOPs/s,而非matmul FP32的吞吐量仅为19.5 TFLOPs/s。因此,FlashAttention-2通过优化算法,尽可能地减少非matmul操作,从而保持高吞吐量的执行效率。
-
算法细节:FlashAttention-2的前向传播通过以下步骤实现:
- 将输入矩阵Q、K、V分成大小为𝐵𝑟 × 𝑑的𝑇𝑟块,将输出矩阵O和logsumexp𝐿也相应地分块。
- 在每个线程块内部分配工作以最大化GPU资源的利用。
- 引入了在线Softmax技巧,通过有效管理和缩放中间结果,减少了不必要的计算开销。
反向传播
FlashAttention-2的反向传播与FlashAttention类似,但也有一些微调:
- 仅使用逐行logsumexp 𝐿,而不是softmax中的最大值和指数和。
- 使用类似的分块策略来优化计算和内存访问,以提高反向传播的效率和性能。
FlashAttention-2在并行性和工作分区方面进行了深入优化,以在GPU上实现更高的计算效率和性能。本节详细讨论了FlashAttention-2的并行化策略和工作分区方法。
并行性
前向传播
在FlashAttention-2中,前向传播的并行化策略如下:
-
线程块调度:每个注意力头使用一个线程块来处理,总共有batch size × number of heads个线程块。每个线程块被调度到一个流多处理器(SM)上执行。例如,Nvidia A100 GPU上有108个这样的SM。这种调度在大量线程块(如≥ 80)时非常高效,因为可以充分利用GPU的计算资源。
-
对长序列的优化:对于长序列(通常意味着较小的batch size或较少的头数),为了更好地利用GPU上的多处理器,FlashAttention-2额外并行化了序列长度维度。这在这种情况下显著提高了性能和效率。
反向传播
在反向传播中,为了避免在不同列块之间的共享计算,FlashAttention-2采用了类似的并行化策略:
- 线程块调度:每个列块使用一个线程块来处理。通过使用原子加操作来在不同线程块之间进行通信,以更新dQ,从而避免了共享内存的读写冲突。
工作分区
前向传播
在前向传播中,FlashAttention-2改进了工作分区策略,避免了FlashAttention中的"split-K"方案,具体包括:
- K和V的分割:FlashAttention-2将Q分割到4个线程束(warp)中,同时使得K和V对所有线程束可访问。每个线程束执行矩阵乘法以获取QK>的一部分,并将其与V的一部分相乘,从而获得对应输出的片段。这种改进减少了线程束之间的通信,降低了共享内存的读写次数,从而提升了性能。
反向传播
在反向传播中,为了避免"split-K"方案带来的同步问题,FlashAttention-2选择了适当的线程束分区策略,以优化计算和内存访问效率。
3.1.1 前向传播公式解析
在前向传播中,我们使用在线Softmax技巧,并进行了两处微调来减少非矩阵乘法的浮点运算。
1. 不对中间结果进行缩放
原始的公式是:
O
(
2
)
=
diag
(
ℓ
(
1
)
ℓ
(
2
)
)
−
1
O
(
1
)
+
diag
(
ℓ
(
2
)
)
−
1
e
S
(
2
)
−
m
(
2
)
V
(
2
)
O^{(2)} = \text{diag}\left(\frac{\ell^{(1)}}{\ell^{(2)}}\right)^{-1} O^{(1)} + \text{diag}(\ell^{(2)})^{-1} e^{S^{(2)} - m^{(2)}} V^{(2)}
O(2)=diag(ℓ(2)ℓ(1))−1O(1)+diag(ℓ(2))−1eS(2)−m(2)V(2)
这里,
(
O
(
2
)
)
(O^{(2)})
(O(2)) 表示第二块的输出,
(
ℓ
(
1
)
)
(\ell^{(1)})
(ℓ(1))和
(
ℓ
(
2
)
)
(\ell^{(2)})
(ℓ(2))分别是第一块和第二块的行和。我们先对第一项
(
O
(
1
)
)
(O^{(1)})
(O(1)) 进行缩放,再加上经过软max计算的第二项。改进后,简化为:
O
~
(
2
)
=
diag
(
ℓ
(
1
)
)
−
1
O
(
1
)
+
e
S
(
2
)
−
m
(
2
)
V
(
2
)
\tilde{O}^{(2)} = \text{diag}(\ell^{(1)})^{-1} O^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)}
O~(2)=diag(ℓ(1))−1O(1)+eS(2)−m(2)V(2)
我们不再缩放第二项,只保留未缩放的中间结果
(
O
~
(
2
)
)
(\tilde{O}^{(2)})
(O~(2))。最后一步,对最终结果进行缩放:
O
(
2
)
=
diag
(
ℓ
(
last
)
)
−
1
O
~
(
last
)
O^{(2)} = \text{diag}(\ell^{(\text{last})})^{-1} \tilde{O}^{(\text{last})}
O(2)=diag(ℓ(last))−1O~(last)
2. 使用logsumexp简化存储
我们只需存储logsumexp:
L
(
j
)
=
m
(
j
)
+
log
(
ℓ
(
j
)
)
L^{(j)} = m^{(j)} + \log(\ell^{(j)})
L(j)=m(j)+log(ℓ(j))
这样做简化了存储,减少了计算中需要保存的变量数量。
具体计算过程
假设我们有两个块:
- 第一个块:
m ( 1 ) = rowmax ( S ( 1 ) ) m^{(1)} = \text{rowmax}(S^{(1)}) m(1)=rowmax(S(1))
计算第一个块的行最大值,得到 ( m ( 1 ) ) (m^{(1)}) (m(1))。
ℓ
(
1
)
=
rowsum
(
e
S
(
1
)
−
m
(
1
)
)
\ell^{(1)} = \text{rowsum}(e^{S^{(1)} - m^{(1)}})
ℓ(1)=rowsum(eS(1)−m(1))
计算第一个块的行和,得到
(
ℓ
(
1
)
)
(\ell^{(1)})
(ℓ(1))。
O
~
(
1
)
=
e
S
(
1
)
−
m
(
1
)
V
(
1
)
\tilde{O}^{(1)} = e^{S^{(1)} - m^{(1)}} V^{(1)}
O~(1)=eS(1)−m(1)V(1)
计算未缩放的第一个块的输出。
- 第二个块:
m ( 2 ) = max ( m ( 1 ) , rowmax ( S ( 2 ) ) ) = m m^{(2)} = \max(m^{(1)}, \text{rowmax}(S^{(2)})) = m m(2)=max(m(1),rowmax(S(2)))=m
计算第二个块的行最大值,并与第一个块的最大值进行比较,得到新的最大值 ( m ) (m) (m)。
ℓ
(
2
)
=
e
m
(
1
)
−
m
(
2
)
ℓ
(
1
)
+
rowsum
(
e
S
(
2
)
−
m
(
2
)
)
=
rowsum
(
e
S
(
1
)
−
m
)
+
rowsum
(
e
S
(
2
)
−
m
)
\ell^{(2)} = e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \text{rowsum}(e^{S^{(2)} - m^{(2)}}) = \text{rowsum}(e^{S^{(1)} - m}) + \text{rowsum}(e^{S^{(2)} - m})
ℓ(2)=em(1)−m(2)ℓ(1)+rowsum(eS(2)−m(2))=rowsum(eS(1)−m)+rowsum(eS(2)−m)
计算第二个块的行和,同时考虑第一个块的贡献。
P
~
(
2
)
=
diag
(
ℓ
(
2
)
)
−
1
e
S
(
2
)
−
m
(
2
)
\tilde{P}^{(2)} = \text{diag}(\ell^{(2)})^{-1} e^{S^{(2)} - m^{(2)}}
P~(2)=diag(ℓ(2))−1eS(2)−m(2)
计算缩放后的软max概率。
O
~
(
2
)
=
diag
(
e
m
(
1
)
−
m
(
2
)
)
O
~
(
1
)
+
e
S
(
2
)
−
m
(
2
)
V
(
2
)
=
e
S
(
1
)
−
m
V
(
1
)
+
e
S
(
2
)
−
m
V
(
2
)
\tilde{O}^{(2)} = \text{diag}(e^{m^{(1)} - m^{(2)}}) \tilde{O}^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)} = e^{S^{(1)} - m} V^{(1)} + e^{S^{(2)} - m} V^{(2)}
O~(2)=diag(em(1)−m(2))O~(1)+eS(2)−m(2)V(2)=eS(1)−mV(1)+eS(2)−mV(2)
结合前两个块的结果,得到未缩放的第二个块的输出。
O
(
2
)
=
diag
(
ℓ
(
2
)
)
−
1
O
~
(
2
)
=
O
O^{(2)} = \text{diag}(\ell^{(2)})^{-1} \tilde{O}^{(2)} = O
O(2)=diag(ℓ(2))−1O~(2)=O
最终对未缩放的输出进行缩放,得到最终结果
(
O
)
(O)
(O)。
3.1.2 反向传播公式解析
在反向传播中,我们也进行了简化,只使用logsumexp ( L ) ( L ) (L)。
反向传播的具体计算步骤
-
初始化:
d Q = ( 0 ) N × d dQ = (0)_{N \times d} dQ=(0)N×d
初始化 dQ 为零矩阵。 -
计算梯度:
D = rowsum ( d O ∘ O ) D = \text{rowsum}(dO \circ O) D=rowsum(dO∘O)
计算 dO 和 O 的点积的行和。
每个块的处理:
S
(
j
)
=
Q
i
K
j
T
S^{(j)} = Q_i K_j^T
S(j)=QiKjT
计算块
(
i
)
(i)
(i) 和
(
j
)
(j)
(j) 的乘积
(
S
(
j
)
)
(S^{(j)})
(S(j))。
P
(
j
)
=
e
S
(
j
)
−
L
i
P^{(j)} = e^{S^{(j)} - L_i}
P(j)=eS(j)−Li
计算
(
P
(
j
)
)
(P^{(j)})
(P(j)),即
(
S
(
j
)
)
(S^{(j)})
(S(j)) 减去
(
L
i
)
(L_i)
(Li) 后的指数。
d
V
j
←
d
V
j
+
(
P
(
j
)
)
T
d
O
i
dV_j \leftarrow dV_j + (P^{(j)})^T dO_i
dVj←dVj+(P(j))TdOi
更新 dV_j。
d
P
(
j
)
=
d
O
i
V
j
T
dP^{(j)} = dO_i V_j^T
dP(j)=dOiVjT
计算 dP^{(j)}。
d
S
(
j
)
=
P
(
j
)
∘
(
d
P
(
j
)
−
D
i
)
dS^{(j)} = P^{(j)} \circ (dP^{(j)} - D_i)
dS(j)=P(j)∘(dP(j)−Di)
计算 dS^{(j)}。
d
Q
i
←
d
Q
i
+
d
S
(
j
)
K
j
dQ_i \leftarrow dQ_i + dS^{(j)} K_j
dQi←dQi+dS(j)Kj
更新 dQ_i。
d
K
j
←
d
K
j
+
(
d
S
(
j
)
)
T
Q
i
dK_j \leftarrow dK_j + (dS^{(j)})^T Q_i
dKj←dKj+(dS(j))TQi
更新 dK_j。
通过这些步骤,我们完成了反向传播的计算,得到dQ、dK和dV。
多查询注意力和分组查询注意力
在这些变体中,我们将多个头的查询集中到同一个头的键和值上,以减少KV缓存的大小。在反向传播中,我们需要跨不同头对梯度dK和dV求和。