FlashInfer - FlashAttention 分块计算(Blockwise Computation) 和 IO 感知优化(IO-Aware Optimization)

FlashInfer - FlashAttention 分块计算(Blockwise Computation) 和 IO 感知优化(IO-Aware Optimization)

flyfish

在 FlashAttention 和注意力机制的上下文中,点积 指的是 矩阵乘法,但更具体地说是 查询矩阵 Q Q Q 与键矩阵 K K K 的转置 K T K^T KT 之间的矩阵乘法,在注意力机制的上下文中,点积(Dot Product) 通常指的就是 矩阵乘法

1. 数学定义

  • 向量点积:两个向量 a \mathbf{a} a b \mathbf{b} b 的点积为 a ⋅ b = ∑ i = 1 n a i b i \mathbf{a} \cdot \mathbf{b} = \sum_{i=1}^n a_i b_i ab=i=1naibi
  • 矩阵乘法:两个矩阵 A A A B B B 的乘法 C = A B C = AB C=AB 中,每个元素 C i , j C_{i,j} Ci,j A A A 的第 i i i 行与 B B B 的第 j j j 列的 点积

2. 在注意力机制中的应用

在自注意力计算 Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(d QKT)V 中:

  • Q K T QK^T QKT 的计算本质是 矩阵乘法,其中:
    • Q Q Q 是查询矩阵(形状为 batch_size × seq_len_q × head_dim \text{batch\_size} \times \text{seq\_len\_q} \times \text{head\_dim} batch_size×seq_len_q×head_dim)。
    • K K K 是键矩阵(形状为 batch_size × seq_len_k × head_dim \text{batch\_size} \times \text{seq\_len\_k} \times \text{head\_dim} batch_size×seq_len_k×head_dim)。
    • Q K T QK^T QKT 的结果是注意力得分矩阵(形状为 batch_size × seq_len_q × seq_len_k \text{batch\_size} \times \text{seq\_len\_q} \times \text{seq\_len\_k} batch_size×seq_len_q×seq_len_k)。
  • 这里的 点积 是指 Q Q Q 的每一行(查询向量)与 K K K 的每一行(键向量)之间的点积,通过矩阵乘法一次性计算所有组合的点积。

3. FlashAttention 中的优化

FlashAttention 的核心是 分块矩阵乘法,将大矩阵分解为小的块(block)进行计算:

  • 每个块的计算仍然是矩阵乘法(即点积的批量计算)。
  • 通过优化块的大小和内存访问模式,减少中间结果的存储,从而降低显存占用。

4. 术语混用的原因

  • 在深度学习领域,“点积” 和 “矩阵乘法” 常被互换使用,因为矩阵乘法本质上是多个点积的组合。
  • 在注意力机制的文献中,“点积注意力”(Dot-Product Attention)特指使用矩阵乘法 Q K T QK^T QKT 计算注意力得分的方法。

5. 示例对比

场景数学表示等价操作
向量点积 a ⋅ b \mathbf{a} \cdot \mathbf{b} ab对应元素相乘后求和
矩阵乘法 C = A B C = AB C=AB C i , j = ∑ k A i , k B k , j C_{i,j} = \sum_k A_{i,k} B_{k,j} Ci,j=kAi,kBk,j
注意力得分计算 Q K T QK^T QKT Q Q Q 的每行与 K K K 的每行的点积

FlashAttention 原理:从内存瓶颈到高效计算

FlashAttention 是针对长序列自注意力的高效实现,核心目标是解决传统注意力 O(N²) 的显存占用和计算低效问题。其核心原理可概括为 分块计算(Blockwise Computation)IO 感知优化(IO-Aware Optimization),具体如下:

1. 传统注意力的瓶颈

标准自注意力计算流程为:

  1. 计算 Query (Q) 与 Key (K) 的点积矩阵 Q K T QK^T QKT,得到注意力得分矩阵 S S S
  2. S S S 应用 softmax 得到归一化权重 A A A
  3. 计算 A A A 与 Value (V) 的加权和,得到输出 O O O

问题

  • 显存爆炸:存储 S S S A A A 需要 O ( N 2 d ) O(N²d) O(N2d) 显存(d 为头维度),当 N = 8 K N=8K N=8K d = 128 d=128 d=128 时,单头显存占用约 8GB,多卡并行时难以承受。
  • 计算低效:矩阵乘法和 softmax 的内存访问带宽成为瓶颈,GPU 计算资源利用率低。
2. FlashAttention 的核心优化原理

FlashAttention 通过 分块重排序 将显存占用降至 O(Nd),并提升计算效率,核心步骤如下:

(1)分块点积计算(Blockwise Dot Product)
  • 将 Q、K、V 按序列长度分成固定大小的块(如块大小 B = 128 B=128 B=128),假设序列长度为 N N N,则分为 M = N / B M = N/B M=N/B 块。
  • 对每个 Query 块 Q i Q_i Qi,依次与所有 Key 块 K j K_j Kj 计算点积 S i , j = Q i K j T S_{i,j} = Q_i K_j^T Si,j=QiKjT,而非一次性计算所有 Q K T QK^T QKT
  • 关键:计算完 S i , j S_{i,j} Si,j 后,立即对其进行 softmax 并与对应的 Value 块 V j V_j Vj 加权求和,无需存储完整的 S S S A A A
(2)softmax 归一化的跨块修正
  • 传统 softmax 对全局得分归一化,但分块计算时,每个块的得分范围不同,需通过 跨块最大值修正 保证归一化正确:
    1. 先计算所有块的得分最大值 max ⁡ ( S ) \max(S) max(S),用于调整每个块的得分偏移量。
    2. 对每个块 S i , j S_{i,j} Si,j 减去全局最大值,再计算指数和归一化,避免数值溢出。
(3)IO 感知的内存访问优化
  • 合并访存(Coalesced Memory Access):将数据按 GPU 线程块(thread block)的大小对齐,确保每个线程访问连续的内存地址,减少显存事务数。
  • 重用寄存器与共享内存:将高频访问的块数据存储在寄存器或共享内存中,减少对全局显存的依赖。
  • 计算与通信重叠:在计算当前块时,预取下一块数据到共享内存,隐藏内存延迟。
(4)支持因果掩码(Causal Mask)
  • 在生成场景(如解码阶段),需屏蔽未来位置的注意力。FlashAttention 通过在分块计算时,仅允许 Q i Q_i Qi K j K_j Kj j ≤ i j \leq i ji)的块进行交互,天然支持因果关系,无需额外掩码矩阵存储。

FlashInfer 对 FlashAttention 的实现与扩展

FlashInfer 基于 FlashAttention 原理,结合 GPU 硬件特性和 LLM 推理需求,实现了高性能内核,并增加了以下关键特性:

1. 硬件感知的内核生成
  • Tensor Core 优化:针对 NVIDIA GPU 的 Tensor Core(如 Volta 架构后的 FP16/FP8 矩阵运算单元),生成专用的块矩阵乘法内核,支持 混合精度计算(如 FP8 点积 + FP16 累加),在保持精度的同时提升计算吞吐量。
  • CUDA 线程调度:通过模板化代码动态调整线程块大小(如 128x128 线程块),适配不同 GPU 架构(A100、H100 等),最大化并行度。
2. 动态批处理与负载均衡
  • 解耦 Plan/Run 阶段
    • Plan 阶段:预处理可变长度输入的块划分策略,生成调度计划(如哪些块需要计算,如何分配线程块)。
    • Run 阶段:按计划并行执行块计算,避免传统方法中因输入长度不一致导致的 GPU 线程空闲。
  • 分页 KV 缓存(PageAttention 集成):将 KV 数据按页存储(非连续内存),通过页表索引快速定位块数据,支持超长序列的高效访问(如 16K+ 上下文)。
3. 内存效率增强
  • Head-Query 融合:在 Grouped-Query Attention(GQA)中,将多个 Query 头的计算融合到同一内核,减少显存访问次数。
  • Cascade Attention 分层缓存:对多层 Transformer,分层复用 KV 缓存,避免重复计算共享前缀,进一步降低显存占用。
4. 接口与集成
  • PyTorch API 封装:提供简单易用的接口(如 single_decode_with_kv_cachesingle_prefill_with_kv_cache),支持 RoPE 位置编码的动态计算,无需用户手动处理底层块划分。
  • JIT 编译与自定义:通过 TVM 绑定或 C++ 模板,允许用户自定义注意力变体(如修改块大小、掩码策略),动态生成专用内核,适配特殊场景。
5. 性能优化细节
  • FP8 支持:利用 NVIDIA 的 FP8 数据格式(如 E4M3、E5M2),将点积计算的显存带宽需求降低 50%,同时通过动态缩放保持数值稳定性。
  • CUDAGraph 与 torch.compile:内核可被捕获到 CUDA 图中,减少内核启动开销,适用于低延迟推理(如单 token 生成)。

FlashAttention 原理与 FlashInfer 实现的核心差异

维度FlashAttention 原理FlashInfer 实现扩展
核心目标长序列内存优化与计算加速LLM 推理全场景优化(解码/预填充、动态批处理)
硬件适配通用 GPU 优化Tensor Core/FP8 专用内核、多 GPU 架构支持
内存管理分块计算降低峰值显存分页 KV 缓存、Head-Query 融合、Cascade 分层
输入支持固定长度块处理可变长度输入的负载均衡调度
接口与生态基础算法实现PyTorch/TVM/C++ 多 API、JIT 自定义支持

关键公式与性能对比

  • 显存占用:传统注意力 O ( N 2 d ) O(N²d) O(N2d) → FlashAttention O ( N d M ) O(NdM) O(NdM)(M 为块数, M ≪ N M \ll N MN)。
  • 速度提升:在 N = 8192 N=8192 N=8192 时,FlashInfer 的 FlashAttention 内核比 PyTorch 原生实现快 4-6 倍,显存占用减少 70% 以上(数据来源:FlashInfer 基准测试)。

FlashInfer 让 FlashAttention 从理论优化变为可落地的高性能内核,成为 LLM 推理服务的关键基础设施。

1. 传统注意力机制的瓶颈

标准自注意力计算为:
Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(d QKT)V
其中:

  • Q ∈ R n × d Q \in \mathbb{R}^{n \times d} QRn×d(查询矩阵)
  • K ∈ R m × d K \in \mathbb{R}^{m \times d} KRm×d(键矩阵)
  • V ∈ R m × d V \in \mathbb{R}^{m \times d} VRm×d(值矩阵)
  • n n n 是目标序列长度, m m m 是源序列长度, d d d 是维度

问题

  • 内存瓶颈:计算 Q K T QK^T QKT 需要存储完整的注意力矩阵(大小为 n × m n \times m n×m),当 n n n m m m 很大时(如长文本),显存占用呈平方级增长。
  • 计算低效:即使使用 GPU 并行计算,大量时间浪费在内存读写(带宽瓶颈)而非真正的计算上。
2. FlashAttention 的核心创新

FlashAttention 通过以下技术突破传统瓶颈:

2.1 分块计算(Blockwise Computation)

Q Q Q K K K V V V 分成小的块(block),逐块计算注意力,避免一次性加载整个矩阵:

  1. 分块矩阵乘法:将 Q Q Q 分成块 Q i Q_i Qi K K K 分成块 K j K_j Kj,每次计算 Q i K j T Q_i K_j^T QiKjT 的一部分。
  2. 增量 softmax:在计算每个块的注意力得分时,同步更新 softmax 的归一化因子,避免存储完整的注意力矩阵。
2.2 IO-Aware 内存调度

通过精心设计的内存访问模式,最小化数据移动:

  • Tile 策略:将矩阵分块为更小的 tile(如 16x16),每个 tile 仅加载一次到寄存器或共享内存。
  • 读写平衡:每个数据块在计算后立即被重用,减少重复访问显存的开销。
2.3 张量核心优化

利用 GPU 的 Tensor Core 加速矩阵乘法,FlashAttention 针对 Tensor Core 的数据布局(如 16x16x16)进行优化,进一步提升计算效率。

3. 数学表达与伪代码

FlashAttention 的核心算法可简化为:

output = zeros(n, d)
l = zeros(n)  # 存储每行的 softmax 归一化因子
m = -inf * ones(n)  # 存储每行的最大值

for j in range(num_blocks):
    # 加载 K_j 和 V_j 块
    K_j = load_block(K, j)
    V_j = load_block(V, j)
    
    for i in range(num_blocks):
        # 加载 Q_i 块
        Q_i = load_block(Q, i)
        
        # 计算块内注意力得分
        scores = matmul(Q_i, K_j^T) / sqrt(d)
        
        # 更新最大值和归一化因子
        m_i_new = max(m[i], scores.max())
        l_i_new = l[i] * exp(m[i] - m_i_new) + scores.exp().sum(dim=1)
        
        # 计算加权值并更新输出
        p = exp(scores - m_i_new.unsqueeze(1))
        output[i] = (output[i] * l[i].unsqueeze(1) + matmul(p, V_j)) / l_i_new.unsqueeze(1)
        
        # 更新最大值和归一化因子
        m[i] = m_i_new
        l[i] = l_i_new

FlashInfer 对 FlashAttention 的实现

1. 架构设计

FlashInfer 的 FlashAttention 实现基于以下层次:

  • CUDA 内核层:直接编写优化的 CUDA 代码,利用线程块(thread block)和共享内存(shared memory)实现高效分块计算。
  • PyTorch 绑定层:通过 PyTorch C++ 扩展将 CUDA 内核封装为 Python 接口,无缝集成到深度学习框架中。
  • 高级 API 层:提供面向用户的简洁 API,如 single_decode_with_kv_cachebatch_decode_with_kv_cache
2. 关键优化技术
2.1 混合精度计算
  • 支持 FP16/FP8 精度,减少内存占用和计算量。
  • 针对不同 GPU 架构(如 A100、H100)优化 Tensor Core 使用。
2.2 KV 缓存优化
  • 持久化 KV 缓存:在生成过程中复用之前计算的键值对,避免重复计算。
  • 内存池管理:预分配固定大小的内存池,减少动态内存分配开销。
2.3 负载均衡调度
  • 将注意力计算分为 planrun 阶段:
    • plan 阶段:根据输入序列长度和 GPU 资源,动态规划最优的分块策略。
    • run 阶段:执行预规划的计算,减少运行时调度开销。
2.4 与其他 FlashInfer 技术协同
  • PageAttention:将 FlashAttention 与分页 KV 缓存结合,支持超长序列的高效处理。
  • Cascade Attention:实现多级 KV 缓存,优先处理高频访问的 token。
3. 代码示例

以下是 FlashInfer 中 FlashAttention API 的简化调用示例:

import torch
import flashinfer

# 初始化 KV 缓存
kv_cache = flashinfer.init_kv_cache(
    batch_size=16,
    max_seq_len=4096,
    num_heads=32,
    head_dim=128,
    dtype=torch.float16,
    device="cuda"
)

# 输入查询张量
q = torch.randn(16, 32, 128, dtype=torch.float16, device="cuda")

# 执行 FlashAttention
output = flashinfer.single_decode_with_kv_cache(
    q=q,
    kv_cache=kv_cache,
    seq_lens=torch.tensor([512] * 16, dtype=torch.int32, device="cuda")
)

FlashAttention 通过分块计算、IO-aware 内存调度和 Tensor Core 优化,突破了传统注意力机制的内存和计算瓶颈。FlashInfer 进一步将其集成到 LLM 推理框架中,结合 KV 缓存管理、混合精度计算和负载均衡,为长序列和高并发场景提供了极致优化的解决方案。

FlashAttention中分块矩阵乘法(点积的批量计算)的过程

示例:分块矩阵乘法计算注意力得分

假设我们有以下参数:

  • 序列长度 n = 8 n = 8 n=8(即8个token)
  • 头维度 d = 4 d = 4 d=4
  • 分块大小 b = 4 b = 4 b=4(将序列分为2块,每块4个token)
1. 原始矩阵表示

查询矩阵 Q Q Q 和键矩阵 K K K 分别为:
Q = [ q 1 q 2 q 3 q 4 q 5 q 6 q 7 q 8 ] , K = [ k 1 k 2 k 3 k 4 k 5 k 6 k 7 k 8 ] Q = \begin{bmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \\ q_5 \\ q_6 \\ q_7 \\ q_8 \end{bmatrix}, \quad K = \begin{bmatrix} k_1 \\ k_2 \\ k_3 \\ k_4 \\ k_5 \\ k_6 \\ k_7 \\ k_8 \end{bmatrix} Q= q1q2q3q4q5q6q7q8 ,K= k1k2k3k4k5k6k7k8
其中每个 q i , k i ∈ R 4 q_i, k_i \in \mathbb{R}^4 qi,kiR4

传统注意力得分计算
scores = Q K T = [ q 1 ⋅ k 1 q 1 ⋅ k 2 ⋯ q 1 ⋅ k 8 q 2 ⋅ k 1 q 2 ⋅ k 2 ⋯ q 2 ⋅ k 8 ⋮ ⋮ ⋱ ⋮ q 8 ⋅ k 1 q 8 ⋅ k 2 ⋯ q 8 ⋅ k 8 ] \text{scores} = QK^T = \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \cdots & q_1 \cdot k_8 \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \cdots & q_2 \cdot k_8 \\ \vdots & \vdots & \ddots & \vdots \\ q_8 \cdot k_1 & q_8 \cdot k_2 & \cdots & q_8 \cdot k_8 \end{bmatrix} scores=QKT= q1k1q2k1q8k1q1k2q2k2q8k2q1k8q2k8q8k8
这需要存储完整的 8 × 8 8 \times 8 8×8 矩阵,内存复杂度为 O ( n 2 ) O(n^2) O(n2)

2. FlashAttention的分块计算

Q Q Q K K K 分别分成2块:
Q = [ Q 1 Q 2 ] , K = [ K 1 K 2 ] Q = \begin{bmatrix} Q_1 \\ Q_2 \end{bmatrix}, \quad K = \begin{bmatrix} K_1 \\ K_2 \end{bmatrix} Q=[Q1Q2],K=[K1K2]
其中:
Q 1 = [ q 1 q 2 q 3 q 4 ] , Q 2 = [ q 5 q 6 q 7 q 8 ] , K 1 = [ k 1 k 2 k 3 k 4 ] , K 2 = [ k 5 k 6 k 7 k 8 ] Q_1 = \begin{bmatrix} q_1 \\ q_2 \\ q_3 \\ q_4 \end{bmatrix}, \quad Q_2 = \begin{bmatrix} q_5 \\ q_6 \\ q_7 \\ q_8 \end{bmatrix}, \quad K_1 = \begin{bmatrix} k_1 \\ k_2 \\ k_3 \\ k_4 \end{bmatrix}, \quad K_2 = \begin{bmatrix} k_5 \\ k_6 \\ k_7 \\ k_8 \end{bmatrix} Q1= q1q2q3q4 ,Q2= q5q6q7q8 ,K1= k1k2k3k4 ,K2= k5k6k7k8

分块计算步骤

步骤1:计算 Q 1 Q_1 Q1 K 1 K_1 K1 的得分

scores 1 , 1 = Q 1 K 1 T = [ q 1 ⋅ k 1 q 1 ⋅ k 2 q 1 ⋅ k 3 q 1 ⋅ k 4 q 2 ⋅ k 1 q 2 ⋅ k 2 q 2 ⋅ k 3 q 2 ⋅ k 4 q 3 ⋅ k 1 q 3 ⋅ k 2 q 3 ⋅ k 3 q 3 ⋅ k 4 q 4 ⋅ k 1 q 4 ⋅ k 2 q 4 ⋅ k 3 q 4 ⋅ k 4 ] \text{scores}_{1,1} = Q_1 K_1^T = \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & q_1 \cdot k_3 & q_1 \cdot k_4 \\ q_2 \cdot k_1 & q_2 \cdot k_2 & q_2 \cdot k_3 & q_2 \cdot k_4 \\ q_3 \cdot k_1 & q_3 \cdot k_2 & q_3 \cdot k_3 & q_3 \cdot k_4 \\ q_4 \cdot k_1 & q_4 \cdot k_2 & q_4 \cdot k_3 & q_4 \cdot k_4 \end{bmatrix} scores1,1=Q1K1T= q1k1q2k1q3k1q4k1q1k2q2k2q3k2q4k2q1k3q2k3q3k3q4k3q1k4q2k4q3k4q4k4

步骤2:计算 Q 1 Q_1 Q1 K 2 K_2 K2 的得分

scores 1 , 2 = Q 1 K 2 T = [ q 1 ⋅ k 5 q 1 ⋅ k 6 q 1 ⋅ k 7 q 1 ⋅ k 8 q 2 ⋅ k 5 q 2 ⋅ k 6 q 2 ⋅ k 7 q 2 ⋅ k 8 q 3 ⋅ k 5 q 3 ⋅ k 6 q 3 ⋅ k 7 q 3 ⋅ k 8 q 4 ⋅ k 5 q 4 ⋅ k 6 q 4 ⋅ k 7 q 4 ⋅ k 8 ] \text{scores}_{1,2} = Q_1 K_2^T = \begin{bmatrix} q_1 \cdot k_5 & q_1 \cdot k_6 & q_1 \cdot k_7 & q_1 \cdot k_8 \\ q_2 \cdot k_5 & q_2 \cdot k_6 & q_2 \cdot k_7 & q_2 \cdot k_8 \\ q_3 \cdot k_5 & q_3 \cdot k_6 & q_3 \cdot k_7 & q_3 \cdot k_8 \\ q_4 \cdot k_5 & q_4 \cdot k_6 & q_4 \cdot k_7 & q_4 \cdot k_8 \end{bmatrix} scores1,2=Q1K2T= q1k5q2k5q3k5q4k5q1k6q2k6q3k6q4k6q1k7q2k7q3k7q4k7q1k8q2k8q3k8q4k8

步骤3:计算 Q 2 Q_2 Q2 K 1 K_1 K1 的得分

scores 2 , 1 = Q 2 K 1 T = [ q 5 ⋅ k 1 q 5 ⋅ k 2 q 5 ⋅ k 3 q 5 ⋅ k 4 q 6 ⋅ k 1 q 6 ⋅ k 2 q 6 ⋅ k 3 q 6 ⋅ k 4 q 7 ⋅ k 1 q 7 ⋅ k 2 q 7 ⋅ k 3 q 7 ⋅ k 4 q 8 ⋅ k 1 q 8 ⋅ k 2 q 8 ⋅ k 3 q 8 ⋅ k 4 ] \text{scores}_{2,1} = Q_2 K_1^T = \begin{bmatrix} q_5 \cdot k_1 & q_5 \cdot k_2 & q_5 \cdot k_3 & q_5 \cdot k_4 \\ q_6 \cdot k_1 & q_6 \cdot k_2 & q_6 \cdot k_3 & q_6 \cdot k_4 \\ q_7 \cdot k_1 & q_7 \cdot k_2 & q_7 \cdot k_3 & q_7 \cdot k_4 \\ q_8 \cdot k_1 & q_8 \cdot k_2 & q_8 \cdot k_3 & q_8 \cdot k_4 \end{bmatrix} scores2,1=Q2K1T= q5k1q6k1q7k1q8k1q5k2q6k2q7k2q8k2q5k3q6k3q7k3q8k3q5k4q6k4q7k4q8k4

步骤4:计算 Q 2 Q_2 Q2 K 2 K_2 K2 的得分

scores 2 , 2 = Q 2 K 2 T = [ q 5 ⋅ k 5 q 5 ⋅ k 6 q 5 ⋅ k 7 q 5 ⋅ k 8 q 6 ⋅ k 5 q 6 ⋅ k 6 q 6 ⋅ k 7 q 6 ⋅ k 8 q 7 ⋅ k 5 q 7 ⋅ k 6 q 7 ⋅ k 7 q 7 ⋅ k 8 q 8 ⋅ k 5 q 8 ⋅ k 6 q 8 ⋅ k 7 q 8 ⋅ k 8 ] \text{scores}_{2,2} = Q_2 K_2^T = \begin{bmatrix} q_5 \cdot k_5 & q_5 \cdot k_6 & q_5 \cdot k_7 & q_5 \cdot k_8 \\ q_6 \cdot k_5 & q_6 \cdot k_6 & q_6 \cdot k_7 & q_6 \cdot k_8 \\ q_7 \cdot k_5 & q_7 \cdot k_6 & q_7 \cdot k_7 & q_7 \cdot k_8 \\ q_8 \cdot k_5 & q_8 \cdot k_6 & q_8 \cdot k_7 & q_8 \cdot k_8 \end{bmatrix} scores2,2=Q2K2T= q5k5q6k5q7k5q8k5q5k6q6k6q7k6q8k6q5k7q6k7q7k7q8k7q5k8q6k8q7k8q8k8

3. 内存优化分析
  • 传统方法:需要存储完整的 8 × 8 8 \times 8 8×8 得分矩阵,共 8 × 8 = 64 8 \times 8 = 64 8×8=64 个元素。
  • 分块方法:每次只需存储 4 × 4 = 16 4 \times 4 = 16 4×4=16 个元素(每个块的得分),内存峰值降低 4 4 4 倍。
4. 结合 softmax 和值矩阵计算

在实际计算中,每个块的得分会立即与对应的 softmax 和值矩阵 V V V 结合:

例如,对于块 Q 1 Q_1 Q1 K 1 K_1 K1

  1. 计算块内 softmax:
    weights 1 , 1 = softmax ( scores 1 , 1 ) \text{weights}_{1,1} = \text{softmax}(\text{scores}_{1,1}) weights1,1=softmax(scores1,1)
  2. 加载对应的值矩阵块 V 1 V_1 V1
    V 1 = [ v 1 v 2 v 3 v 4 ] V_1 = \begin{bmatrix} v_1 \\ v_2 \\ v_3 \\ v_4 \end{bmatrix} V1= v1v2v3v4
  3. 计算部分输出:
    output 1 ′ = weights 1 , 1 ⋅ V 1 \text{output}_1' = \text{weights}_{1,1} \cdot V_1 output1=weights1,1V1

类似地,处理 Q 1 Q_1 Q1 K 2 K_2 K2 的块:
output 1 ′ ′ = softmax ( scores 1 , 2 ) ⋅ V 2 \text{output}_1'' = \text{softmax}(\text{scores}_{1,2}) \cdot V_2 output1′′=softmax(scores1,2)V2

最终,块 Q 1 Q_1 Q1 的完整输出为:
output 1 = output 1 ′ + output 1 ′ ′ \text{output}_1 = \text{output}_1' + \text{output}_1'' output1=output1+output1′′

5. 数学公式变化

传统方法

output = softmax ( Q K T ) V \text{output} = \text{softmax}(QK^T)V output=softmax(QKT)V
分块方法
output i = ∑ j softmax ( scores i , j ) V j \text{output}_i = \sum_j \text{softmax}(\text{scores}_{i,j})V_j outputi=jsoftmax(scoresi,j)Vj
其中 scores i , j = Q i K j T \text{scores}_{i,j} = Q_i K_j^T scoresi,j=QiKjT 通过分块计算,FlashAttention将内存复杂度从 O ( n 2 ) O(n^2) O(n2) 降至 O ( n b ) O(nb) O(nb) b b b 为块大小),同时保持了计算结果的一致性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二分掌柜的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值