FlashInfer - FlashAttention 分块计算(Blockwise Computation) 和 IO 感知优化(IO-Aware Optimization)
flyfish
在 FlashAttention 和注意力机制的上下文中,点积 指的是 矩阵乘法,但更具体地说是 查询矩阵 Q Q Q 与键矩阵 K K K 的转置 K T K^T KT 之间的矩阵乘法,在注意力机制的上下文中,点积(Dot Product) 通常指的就是 矩阵乘法。
1. 数学定义
- 向量点积:两个向量 a \mathbf{a} a 和 b \mathbf{b} b 的点积为 a ⋅ b = ∑ i = 1 n a i b i \mathbf{a} \cdot \mathbf{b} = \sum_{i=1}^n a_i b_i a⋅b=∑i=1naibi。
- 矩阵乘法:两个矩阵 A A A 和 B B B 的乘法 C = A B C = AB C=AB 中,每个元素 C i , j C_{i,j} Ci,j 是 A A A 的第 i i i 行与 B B B 的第 j j j 列的 点积。
2. 在注意力机制中的应用
在自注意力计算 Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(dQKT)V 中:
-
Q
K
T
QK^T
QKT 的计算本质是 矩阵乘法,其中:
- Q Q Q 是查询矩阵(形状为 batch_size × seq_len_q × head_dim \text{batch\_size} \times \text{seq\_len\_q} \times \text{head\_dim} batch_size×seq_len_q×head_dim)。
- K K K 是键矩阵(形状为 batch_size × seq_len_k × head_dim \text{batch\_size} \times \text{seq\_len\_k} \times \text{head\_dim} batch_size×seq_len_k×head_dim)。
- Q K T QK^T QKT 的结果是注意力得分矩阵(形状为 batch_size × seq_len_q × seq_len_k \text{batch\_size} \times \text{seq\_len\_q} \times \text{seq\_len\_k} batch_size×seq_len_q×seq_len_k)。
- 这里的 点积 是指 Q Q Q 的每一行(查询向量)与 K K K 的每一行(键向量)之间的点积,通过矩阵乘法一次性计算所有组合的点积。
3. FlashAttention 中的优化
FlashAttention 的核心是 分块矩阵乘法,将大矩阵分解为小的块(block)进行计算:
- 每个块的计算仍然是矩阵乘法(即点积的批量计算)。
- 通过优化块的大小和内存访问模式,减少中间结果的存储,从而降低显存占用。
4. 术语混用的原因
- 在深度学习领域,“点积” 和 “矩阵乘法” 常被互换使用,因为矩阵乘法本质上是多个点积的组合。
- 在注意力机制的文献中,“点积注意力”(Dot-Product Attention)特指使用矩阵乘法 Q K T QK^T QKT 计算注意力得分的方法。
5. 示例对比
场景 | 数学表示 | 等价操作 |
---|---|---|
向量点积 | a ⋅ b \mathbf{a} \cdot \mathbf{b} a⋅b | 对应元素相乘后求和 |
矩阵乘法 | C = A B C = AB C=AB | C i , j = ∑ k A i , k B k , j C_{i,j} = \sum_k A_{i,k} B_{k,j} Ci,j=∑kAi,kBk,j |
注意力得分计算 | Q K T QK^T QKT | Q Q Q 的每行与 K K K 的每行的点积 |
FlashAttention 原理:从内存瓶颈到高效计算
FlashAttention 是针对长序列自注意力的高效实现,核心目标是解决传统注意力 O(N²) 的显存占用和计算低效问题。其核心原理可概括为 分块计算(Blockwise Computation) 和 IO 感知优化(IO-Aware Optimization),具体如下:
1. 传统注意力的瓶颈
标准自注意力计算流程为:
- 计算 Query (Q) 与 Key (K) 的点积矩阵 Q K T QK^T QKT,得到注意力得分矩阵 S S S。
- 对 S S S 应用 softmax 得到归一化权重 A A A。
- 计算 A A A 与 Value (V) 的加权和,得到输出 O O O。
问题:
- 显存爆炸:存储 S S S 和 A A A 需要 O ( N 2 d ) O(N²d) O(N2d) 显存(d 为头维度),当 N = 8 K N=8K N=8K、 d = 128 d=128 d=128 时,单头显存占用约 8GB,多卡并行时难以承受。
- 计算低效:矩阵乘法和 softmax 的内存访问带宽成为瓶颈,GPU 计算资源利用率低。
2. FlashAttention 的核心优化原理
FlashAttention 通过 分块重排序 将显存占用降至 O(Nd),并提升计算效率,核心步骤如下:
(1)分块点积计算(Blockwise Dot Product)
- 将 Q、K、V 按序列长度分成固定大小的块(如块大小 B = 128 B=128 B=128),假设序列长度为 N N N,则分为 M = N / B M = N/B M=N/B 块。
- 对每个 Query 块 Q i Q_i Qi,依次与所有 Key 块 K j K_j Kj 计算点积 S i , j = Q i K j T S_{i,j} = Q_i K_j^T Si,j=QiKjT,而非一次性计算所有 Q K T QK^T QKT。
- 关键:计算完 S i , j S_{i,j} Si,j 后,立即对其进行 softmax 并与对应的 Value 块 V j V_j Vj 加权求和,无需存储完整的 S S S 和 A A A。
(2)softmax 归一化的跨块修正
- 传统 softmax 对全局得分归一化,但分块计算时,每个块的得分范围不同,需通过 跨块最大值修正 保证归一化正确:
- 先计算所有块的得分最大值 max ( S ) \max(S) max(S),用于调整每个块的得分偏移量。
- 对每个块 S i , j S_{i,j} Si,j 减去全局最大值,再计算指数和归一化,避免数值溢出。
(3)IO 感知的内存访问优化
- 合并访存(Coalesced Memory Access):将数据按 GPU 线程块(thread block)的大小对齐,确保每个线程访问连续的内存地址,减少显存事务数。
- 重用寄存器与共享内存:将高频访问的块数据存储在寄存器或共享内存中,减少对全局显存的依赖。
- 计算与通信重叠:在计算当前块时,预取下一块数据到共享内存,隐藏内存延迟。
(4)支持因果掩码(Causal Mask)
- 在生成场景(如解码阶段),需屏蔽未来位置的注意力。FlashAttention 通过在分块计算时,仅允许 Q i Q_i Qi 与 K j K_j Kj( j ≤ i j \leq i j≤i)的块进行交互,天然支持因果关系,无需额外掩码矩阵存储。
FlashInfer 对 FlashAttention 的实现与扩展
FlashInfer 基于 FlashAttention 原理,结合 GPU 硬件特性和 LLM 推理需求,实现了高性能内核,并增加了以下关键特性:
1. 硬件感知的内核生成
- Tensor Core 优化:针对 NVIDIA GPU 的 Tensor Core(如 Volta 架构后的 FP16/FP8 矩阵运算单元),生成专用的块矩阵乘法内核,支持 混合精度计算(如 FP8 点积 + FP16 累加),在保持精度的同时提升计算吞吐量。
- CUDA 线程调度:通过模板化代码动态调整线程块大小(如 128x128 线程块),适配不同 GPU 架构(A100、H100 等),最大化并行度。
2. 动态批处理与负载均衡
- 解耦 Plan/Run 阶段:
- Plan 阶段:预处理可变长度输入的块划分策略,生成调度计划(如哪些块需要计算,如何分配线程块)。
- Run 阶段:按计划并行执行块计算,避免传统方法中因输入长度不一致导致的 GPU 线程空闲。
- 分页 KV 缓存(PageAttention 集成):将 KV 数据按页存储(非连续内存),通过页表索引快速定位块数据,支持超长序列的高效访问(如 16K+ 上下文)。
3. 内存效率增强
- Head-Query 融合:在 Grouped-Query Attention(GQA)中,将多个 Query 头的计算融合到同一内核,减少显存访问次数。
- Cascade Attention 分层缓存:对多层 Transformer,分层复用 KV 缓存,避免重复计算共享前缀,进一步降低显存占用。
4. 接口与集成
- PyTorch API 封装:提供简单易用的接口(如
single_decode_with_kv_cache
、single_prefill_with_kv_cache
),支持 RoPE 位置编码的动态计算,无需用户手动处理底层块划分。 - JIT 编译与自定义:通过 TVM 绑定或 C++ 模板,允许用户自定义注意力变体(如修改块大小、掩码策略),动态生成专用内核,适配特殊场景。
5. 性能优化细节
- FP8 支持:利用 NVIDIA 的 FP8 数据格式(如 E4M3、E5M2),将点积计算的显存带宽需求降低 50%,同时通过动态缩放保持数值稳定性。
- CUDAGraph 与 torch.compile:内核可被捕获到 CUDA 图中,减少内核启动开销,适用于低延迟推理(如单 token 生成)。
FlashAttention 原理与 FlashInfer 实现的核心差异
维度 | FlashAttention 原理 | FlashInfer 实现扩展 |
---|---|---|
核心目标 | 长序列内存优化与计算加速 | LLM 推理全场景优化(解码/预填充、动态批处理) |
硬件适配 | 通用 GPU 优化 | Tensor Core/FP8 专用内核、多 GPU 架构支持 |
内存管理 | 分块计算降低峰值显存 | 分页 KV 缓存、Head-Query 融合、Cascade 分层 |
输入支持 | 固定长度块处理 | 可变长度输入的负载均衡调度 |
接口与生态 | 基础算法实现 | PyTorch/TVM/C++ 多 API、JIT 自定义支持 |
关键公式与性能对比
- 显存占用:传统注意力 O ( N 2 d ) O(N²d) O(N2d) → FlashAttention O ( N d M ) O(NdM) O(NdM)(M 为块数, M ≪ N M \ll N M≪N)。
- 速度提升:在 N = 8192 N=8192 N=8192 时,FlashInfer 的 FlashAttention 内核比 PyTorch 原生实现快 4-6 倍,显存占用减少 70% 以上(数据来源:FlashInfer 基准测试)。
FlashInfer 让 FlashAttention 从理论优化变为可落地的高性能内核,成为 LLM 推理服务的关键基础设施。
1. 传统注意力机制的瓶颈
标准自注意力计算为:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
)
V
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
Attention(Q,K,V)=softmax(dQKT)V
其中:
- Q ∈ R n × d Q \in \mathbb{R}^{n \times d} Q∈Rn×d(查询矩阵)
- K ∈ R m × d K \in \mathbb{R}^{m \times d} K∈Rm×d(键矩阵)
- V ∈ R m × d V \in \mathbb{R}^{m \times d} V∈Rm×d(值矩阵)
- n n n 是目标序列长度, m m m 是源序列长度, d d d 是维度
问题:
- 内存瓶颈:计算 Q K T QK^T QKT 需要存储完整的注意力矩阵(大小为 n × m n \times m n×m),当 n n n 或 m m m 很大时(如长文本),显存占用呈平方级增长。
- 计算低效:即使使用 GPU 并行计算,大量时间浪费在内存读写(带宽瓶颈)而非真正的计算上。
2. FlashAttention 的核心创新
FlashAttention 通过以下技术突破传统瓶颈:
2.1 分块计算(Blockwise Computation)
将 Q Q Q、 K K K、 V V V 分成小的块(block),逐块计算注意力,避免一次性加载整个矩阵:
- 分块矩阵乘法:将 Q Q Q 分成块 Q i Q_i Qi, K K K 分成块 K j K_j Kj,每次计算 Q i K j T Q_i K_j^T QiKjT 的一部分。
- 增量 softmax:在计算每个块的注意力得分时,同步更新 softmax 的归一化因子,避免存储完整的注意力矩阵。
2.2 IO-Aware 内存调度
通过精心设计的内存访问模式,最小化数据移动:
- Tile 策略:将矩阵分块为更小的 tile(如 16x16),每个 tile 仅加载一次到寄存器或共享内存。
- 读写平衡:每个数据块在计算后立即被重用,减少重复访问显存的开销。
2.3 张量核心优化
利用 GPU 的 Tensor Core 加速矩阵乘法,FlashAttention 针对 Tensor Core 的数据布局(如 16x16x16)进行优化,进一步提升计算效率。
3. 数学表达与伪代码
FlashAttention 的核心算法可简化为:
output = zeros(n, d)
l = zeros(n) # 存储每行的 softmax 归一化因子
m = -inf * ones(n) # 存储每行的最大值
for j in range(num_blocks):
# 加载 K_j 和 V_j 块
K_j = load_block(K, j)
V_j = load_block(V, j)
for i in range(num_blocks):
# 加载 Q_i 块
Q_i = load_block(Q, i)
# 计算块内注意力得分
scores = matmul(Q_i, K_j^T) / sqrt(d)
# 更新最大值和归一化因子
m_i_new = max(m[i], scores.max())
l_i_new = l[i] * exp(m[i] - m_i_new) + scores.exp().sum(dim=1)
# 计算加权值并更新输出
p = exp(scores - m_i_new.unsqueeze(1))
output[i] = (output[i] * l[i].unsqueeze(1) + matmul(p, V_j)) / l_i_new.unsqueeze(1)
# 更新最大值和归一化因子
m[i] = m_i_new
l[i] = l_i_new
FlashInfer 对 FlashAttention 的实现
1. 架构设计
FlashInfer 的 FlashAttention 实现基于以下层次:
- CUDA 内核层:直接编写优化的 CUDA 代码,利用线程块(thread block)和共享内存(shared memory)实现高效分块计算。
- PyTorch 绑定层:通过 PyTorch C++ 扩展将 CUDA 内核封装为 Python 接口,无缝集成到深度学习框架中。
- 高级 API 层:提供面向用户的简洁 API,如
single_decode_with_kv_cache
或batch_decode_with_kv_cache
。
2. 关键优化技术
2.1 混合精度计算
- 支持 FP16/FP8 精度,减少内存占用和计算量。
- 针对不同 GPU 架构(如 A100、H100)优化 Tensor Core 使用。
2.2 KV 缓存优化
- 持久化 KV 缓存:在生成过程中复用之前计算的键值对,避免重复计算。
- 内存池管理:预分配固定大小的内存池,减少动态内存分配开销。
2.3 负载均衡调度
- 将注意力计算分为
plan
和run
阶段:plan
阶段:根据输入序列长度和 GPU 资源,动态规划最优的分块策略。run
阶段:执行预规划的计算,减少运行时调度开销。
2.4 与其他 FlashInfer 技术协同
- PageAttention:将 FlashAttention 与分页 KV 缓存结合,支持超长序列的高效处理。
- Cascade Attention:实现多级 KV 缓存,优先处理高频访问的 token。
3. 代码示例
以下是 FlashInfer 中 FlashAttention API 的简化调用示例:
import torch
import flashinfer
# 初始化 KV 缓存
kv_cache = flashinfer.init_kv_cache(
batch_size=16,
max_seq_len=4096,
num_heads=32,
head_dim=128,
dtype=torch.float16,
device="cuda"
)
# 输入查询张量
q = torch.randn(16, 32, 128, dtype=torch.float16, device="cuda")
# 执行 FlashAttention
output = flashinfer.single_decode_with_kv_cache(
q=q,
kv_cache=kv_cache,
seq_lens=torch.tensor([512] * 16, dtype=torch.int32, device="cuda")
)
FlashAttention 通过分块计算、IO-aware 内存调度和 Tensor Core 优化,突破了传统注意力机制的内存和计算瓶颈。FlashInfer 进一步将其集成到 LLM 推理框架中,结合 KV 缓存管理、混合精度计算和负载均衡,为长序列和高并发场景提供了极致优化的解决方案。
FlashAttention中分块矩阵乘法(点积的批量计算)的过程
示例:分块矩阵乘法计算注意力得分
假设我们有以下参数:
- 序列长度 n = 8 n = 8 n=8(即8个token)
- 头维度 d = 4 d = 4 d=4
- 分块大小 b = 4 b = 4 b=4(将序列分为2块,每块4个token)
1. 原始矩阵表示
查询矩阵
Q
Q
Q 和键矩阵
K
K
K 分别为:
Q
=
[
q
1
q
2
q
3
q
4
q
5
q
6
q
7
q
8
]
,
K
=
[
k
1
k
2
k
3
k
4
k
5
k
6
k
7
k
8
]
Q = \begin{bmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \\ q_5 \\ q_6 \\ q_7 \\ q_8 \end{bmatrix}, \quad K = \begin{bmatrix} k_1 \\ k_2 \\ k_3 \\ k_4 \\ k_5 \\ k_6 \\ k_7 \\ k_8 \end{bmatrix}
Q=
q1q2q3q4q5q6q7q8
,K=
k1k2k3k4k5k6k7k8
其中每个
q
i
,
k
i
∈
R
4
q_i, k_i \in \mathbb{R}^4
qi,ki∈R4。
传统注意力得分计算:
scores
=
Q
K
T
=
[
q
1
⋅
k
1
q
1
⋅
k
2
⋯
q
1
⋅
k
8
q
2
⋅
k
1
q
2
⋅
k
2
⋯
q
2
⋅
k
8
⋮
⋮
⋱
⋮
q
8
⋅
k
1
q
8
⋅
k
2
⋯
q
8
⋅
k
8
]
\text{scores} = QK^T = \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \cdots & q_1 \cdot k_8 \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \cdots & q_2 \cdot k_8 \\ \vdots & \vdots & \ddots & \vdots \\ q_8 \cdot k_1 & q_8 \cdot k_2 & \cdots & q_8 \cdot k_8 \end{bmatrix}
scores=QKT=
q1⋅k1q2⋅k1⋮q8⋅k1q1⋅k2q2⋅k2⋮q8⋅k2⋯⋯⋱⋯q1⋅k8q2⋅k8⋮q8⋅k8
这需要存储完整的
8
×
8
8 \times 8
8×8 矩阵,内存复杂度为
O
(
n
2
)
O(n^2)
O(n2)。
2. FlashAttention的分块计算
将
Q
Q
Q 和
K
K
K 分别分成2块:
Q
=
[
Q
1
Q
2
]
,
K
=
[
K
1
K
2
]
Q = \begin{bmatrix} Q_1 \\ Q_2 \end{bmatrix}, \quad K = \begin{bmatrix} K_1 \\ K_2 \end{bmatrix}
Q=[Q1Q2],K=[K1K2]
其中:
Q
1
=
[
q
1
q
2
q
3
q
4
]
,
Q
2
=
[
q
5
q
6
q
7
q
8
]
,
K
1
=
[
k
1
k
2
k
3
k
4
]
,
K
2
=
[
k
5
k
6
k
7
k
8
]
Q_1 = \begin{bmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \end{bmatrix}, \quad Q_2 = \begin{bmatrix} q_5 \\ q_6 \\ q_7 \\ q_8 \end{bmatrix}, \quad K_1 = \begin{bmatrix} k_1 \\ k_2 \\ k_3 \\ k_4 \end{bmatrix}, \quad K_2 = \begin{bmatrix} k_5 \\ k_6 \\ k_7 \\ k_8 \end{bmatrix}
Q1=
q1q2q3q4
,Q2=
q5q6q7q8
,K1=
k1k2k3k4
,K2=
k5k6k7k8
分块计算步骤:
步骤1:计算 Q 1 Q_1 Q1 与 K 1 K_1 K1 的得分
scores 1 , 1 = Q 1 K 1 T = [ q 1 ⋅ k 1 q 1 ⋅ k 2 q 1 ⋅ k 3 q 1 ⋅ k 4 q 2 ⋅ k 1 q 2 ⋅ k 2 q 2 ⋅ k 3 q 2 ⋅ k 4 q 3 ⋅ k 1 q 3 ⋅ k 2 q 3 ⋅ k 3 q 3 ⋅ k 4 q 4 ⋅ k 1 q 4 ⋅ k 2 q 4 ⋅ k 3 q 4 ⋅ k 4 ] \text{scores}_{1,1} = Q_1 K_1^T = \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & q_1 \cdot k_3 & q_1 \cdot k_4 \\ q_2 \cdot k_1 & q_2 \cdot k_2 & q_2 \cdot k_3 & q_2 \cdot k_4 \\ q_3 \cdot k_1 & q_3 \cdot k_2 & q_3 \cdot k_3 & q_3 \cdot k_4 \\ q_4 \cdot k_1 & q_4 \cdot k_2 & q_4 \cdot k_3 & q_4 \cdot k_4 \end{bmatrix} scores1,1=Q1K1T= q1⋅k1q2⋅k1q3⋅k1q4⋅k1q1⋅k2q2⋅k2q3⋅k2q4⋅k2q1⋅k3q2⋅k3q3⋅k3q4⋅k3q1⋅k4q2⋅k4q3⋅k4q4⋅k4
步骤2:计算 Q 1 Q_1 Q1 与 K 2 K_2 K2 的得分
scores 1 , 2 = Q 1 K 2 T = [ q 1 ⋅ k 5 q 1 ⋅ k 6 q 1 ⋅ k 7 q 1 ⋅ k 8 q 2 ⋅ k 5 q 2 ⋅ k 6 q 2 ⋅ k 7 q 2 ⋅ k 8 q 3 ⋅ k 5 q 3 ⋅ k 6 q 3 ⋅ k 7 q 3 ⋅ k 8 q 4 ⋅ k 5 q 4 ⋅ k 6 q 4 ⋅ k 7 q 4 ⋅ k 8 ] \text{scores}_{1,2} = Q_1 K_2^T = \begin{bmatrix} q_1 \cdot k_5 & q_1 \cdot k_6 & q_1 \cdot k_7 & q_1 \cdot k_8 \\ q_2 \cdot k_5 & q_2 \cdot k_6 & q_2 \cdot k_7 & q_2 \cdot k_8 \\ q_3 \cdot k_5 & q_3 \cdot k_6 & q_3 \cdot k_7 & q_3 \cdot k_8 \\ q_4 \cdot k_5 & q_4 \cdot k_6 & q_4 \cdot k_7 & q_4 \cdot k_8 \end{bmatrix} scores1,2=Q1K2T= q1⋅k5q2⋅k5q3⋅k5q4⋅k5q1⋅k6q2⋅k6q3⋅k6q4⋅k6q1⋅k7q2⋅k7q3⋅k7q4⋅k7q1⋅k8q2⋅k8q3⋅k8q4⋅k8
步骤3:计算 Q 2 Q_2 Q2 与 K 1 K_1 K1 的得分
scores 2 , 1 = Q 2 K 1 T = [ q 5 ⋅ k 1 q 5 ⋅ k 2 q 5 ⋅ k 3 q 5 ⋅ k 4 q 6 ⋅ k 1 q 6 ⋅ k 2 q 6 ⋅ k 3 q 6 ⋅ k 4 q 7 ⋅ k 1 q 7 ⋅ k 2 q 7 ⋅ k 3 q 7 ⋅ k 4 q 8 ⋅ k 1 q 8 ⋅ k 2 q 8 ⋅ k 3 q 8 ⋅ k 4 ] \text{scores}_{2,1} = Q_2 K_1^T = \begin{bmatrix} q_5 \cdot k_1 & q_5 \cdot k_2 & q_5 \cdot k_3 & q_5 \cdot k_4 \\ q_6 \cdot k_1 & q_6 \cdot k_2 & q_6 \cdot k_3 & q_6 \cdot k_4 \\ q_7 \cdot k_1 & q_7 \cdot k_2 & q_7 \cdot k_3 & q_7 \cdot k_4 \\ q_8 \cdot k_1 & q_8 \cdot k_2 & q_8 \cdot k_3 & q_8 \cdot k_4 \end{bmatrix} scores2,1=Q2K1T= q5⋅k1q6⋅k1q7⋅k1q8⋅k1q5⋅k2q6⋅k2q7⋅k2q8⋅k2q5⋅k3q6⋅k3q7⋅k3q8⋅k3q5⋅k4q6⋅k4q7⋅k4q8⋅k4
步骤4:计算 Q 2 Q_2 Q2 与 K 2 K_2 K2 的得分
scores 2 , 2 = Q 2 K 2 T = [ q 5 ⋅ k 5 q 5 ⋅ k 6 q 5 ⋅ k 7 q 5 ⋅ k 8 q 6 ⋅ k 5 q 6 ⋅ k 6 q 6 ⋅ k 7 q 6 ⋅ k 8 q 7 ⋅ k 5 q 7 ⋅ k 6 q 7 ⋅ k 7 q 7 ⋅ k 8 q 8 ⋅ k 5 q 8 ⋅ k 6 q 8 ⋅ k 7 q 8 ⋅ k 8 ] \text{scores}_{2,2} = Q_2 K_2^T = \begin{bmatrix} q_5 \cdot k_5 & q_5 \cdot k_6 & q_5 \cdot k_7 & q_5 \cdot k_8 \\ q_6 \cdot k_5 & q_6 \cdot k_6 & q_6 \cdot k_7 & q_6 \cdot k_8 \\ q_7 \cdot k_5 & q_7 \cdot k_6 & q_7 \cdot k_7 & q_7 \cdot k_8 \\ q_8 \cdot k_5 & q_8 \cdot k_6 & q_8 \cdot k_7 & q_8 \cdot k_8 \end{bmatrix} scores2,2=Q2K2T= q5⋅k5q6⋅k5q7⋅k5q8⋅k5q5⋅k6q6⋅k6q7⋅k6q8⋅k6q5⋅k7q6⋅k7q7⋅k7q8⋅k7q5⋅k8q6⋅k8q7⋅k8q8⋅k8
3. 内存优化分析
- 传统方法:需要存储完整的 8 × 8 8 \times 8 8×8 得分矩阵,共 8 × 8 = 64 8 \times 8 = 64 8×8=64 个元素。
- 分块方法:每次只需存储 4 × 4 = 16 4 \times 4 = 16 4×4=16 个元素(每个块的得分),内存峰值降低 4 4 4 倍。
4. 结合 softmax 和值矩阵计算
在实际计算中,每个块的得分会立即与对应的 softmax 和值矩阵 V V V 结合:
例如,对于块 Q 1 Q_1 Q1 和 K 1 K_1 K1:
- 计算块内 softmax:
weights 1 , 1 = softmax ( scores 1 , 1 ) \text{weights}_{1,1} = \text{softmax}(\text{scores}_{1,1}) weights1,1=softmax(scores1,1) - 加载对应的值矩阵块
V
1
V_1
V1:
V 1 = [ v 1 v 2 v 3 v 4 ] V_1 = \begin{bmatrix} v_1 \\ v_2 \\ v_3 \\ v_4 \end{bmatrix} V1= v1v2v3v4 - 计算部分输出:
output 1 ′ = weights 1 , 1 ⋅ V 1 \text{output}_1' = \text{weights}_{1,1} \cdot V_1 output1′=weights1,1⋅V1
类似地,处理
Q
1
Q_1
Q1 与
K
2
K_2
K2 的块:
output
1
′
′
=
softmax
(
scores
1
,
2
)
⋅
V
2
\text{output}_1'' = \text{softmax}(\text{scores}_{1,2}) \cdot V_2
output1′′=softmax(scores1,2)⋅V2
最终,块
Q
1
Q_1
Q1 的完整输出为:
output
1
=
output
1
′
+
output
1
′
′
\text{output}_1 = \text{output}_1' + \text{output}_1''
output1=output1′+output1′′
5. 数学公式变化
传统方法
output
=
softmax
(
Q
K
T
)
V
\text{output} = \text{softmax}(QK^T)V
output=softmax(QKT)V
分块方法
output
i
=
∑
j
softmax
(
scores
i
,
j
)
V
j
\text{output}_i = \sum_j \text{softmax}(\text{scores}_{i,j})V_j
outputi=j∑softmax(scoresi,j)Vj
其中
scores
i
,
j
=
Q
i
K
j
T
\text{scores}_{i,j} = Q_i K_j^T
scoresi,j=QiKjT 通过分块计算,FlashAttention将内存复杂度从
O
(
n
2
)
O(n^2)
O(n2) 降至
O
(
n
b
)
O(nb)
O(nb)(
b
b
b 为块大小),同时保持了计算结果的一致性。