PagedAttention 是一种针对 Transformer 模型(尤其是大语言模型)中 KV Cache(Key-Value Cache)管理的优化技术,旨在解决长序列推理时 KV Cache 的内存碎片化和高内存占用问题。它通过引入分页机制(类似于操作系统的虚拟内存分页),将 KV Cache 分块存储和管理,从而提高内存利用率、支持更长的序列长度,并提升推理服务的吞吐量和稳定性。以下是对 PagedAttention 的详细介绍,包括其背景、原理、实现方式、优缺点以及应用场景。
1. 背景
在 Transformer 模型的自回归推理中,KV Cache 用于存储注意力机制中的键(Key)和值(Value),以避免重复计算,从而加速生成过程。然而,KV Cache 存在以下挑战:
- 内存占用:KV Cache 的内存需求随序列长度线性增长,尤其在长序列或多用户场景下,显存占用可能成为瓶颈。
- 内存碎片化:动态生成的序列长度不固定(例如不同用户对话长度不同),导致 KV Cache 的内存分配不连续,产生碎片化,降低显存利用率。
- 动态调度难度:在推理服务中,多个请求可能同时处理,KV Cache 的分配和释放需要高效管理以支持高并发。
PagedAttention 通过将 KV Cache 分页存储(类似于操作系统的页面管理),将连续的 KV Cache 分成固定大小的块(Page),并通过映射表动态管理这些块,从而解决上述问题。它最初由 Meta AI 提出,并在 vLLM(一种高效推理引擎)中广泛应用。
2. PagedAttention 的工作原理
PagedAttention 的核心思想是将 KV Cache 的键和值张量分割成固定大小的块(称为“页面”),并通过一个映射表管理这些块的存储和访问。以下是其具体原理:
2.1 传统 KV Cache 的问题
在传统 KV Cache 中:
- 每个序列的 K K K 和 V V V 存储为连续的张量,形状为 ( b a t c h , h e a d s , s e q l e n , h e a d d i m ) (batch, heads, seq_len, head_dim) (batch,heads,seqlen,headdim)。
- 随着序列长度 s e q l e n seq_len seqlen 增加,KV Cache 需要动态扩展内存分配。
- 不同序列的长度不同,导致内存分配不均匀,容易产生碎片。
- 如果一个序列提前结束,其占用的显存可能无法立即被其他序列复用。
2.2 PagedAttention 的分页机制
PagedAttention 引入了以下关键概念:
- 页面(Page):KV Cache 被分割成固定大小的块,每个块存储一定数量的 token 的 K K K 和 V V V。例如,一个页面可能存储 512 个 token 的键和值。
- 页面大小(Page Size):每个页面能容纳的 token 数量,通常是固定的(如 128、256 或 512)。
- 块表(Block Table):每个序列维护一个块表,记录其 KV Cache 使用的页面编号(物理块索引)。
- 内存池:所有页面存储在一个全局内存池中,页面可以动态分配和释放,供不同序列复用。
2.3 工作流程
假设一个 Transformer 模型有 L L L 层, H H H 个注意力头,头维度为 d d d,序列长度为 T T T,PagedAttention 的工作流程如下:
- 初始化:
- 创建一个全局页面池,包含多个固定大小的页面,每个页面存储 N N N 个 token 的 K K K 和 V V V。
- 为每个序列分配一个块表,初始为空。
- 生成新 token:
- 当生成新 token 时,计算其 Q Q Q、 K K K、 V V V。
- 检查当前序列的最后一个页面是否已满:
- 如果未满,将新 token 的 K K K 和 V V V 写入当前页面。
- 如果已满,从内存池分配一个新页面,更新块表,将 K K K 和 V V V 写入新页面。
- 使用块表查找该序列的所有页面,获取完整的 K K K 和 V V V。
- 执行注意力计算:
Attention ( Q , K paged , V paged ) \text{Attention}(Q, K_{\text{paged}}, V_{\text{paged}}) Attention(Q,Kpaged,Vpaged)
- 页面管理:
- 当序列生成结束,释放其占用的页面,归还到内存池。
- 内存池支持动态分配,页面可以被其他序列复用。
2.4 存储结构
- 页面:每个页面存储 N N N 个 token 的 K K K 和 V V V,形状为 ( b a t c h , h e a d s , N , h e a d d i m ) (batch, heads, N, head_dim) (batch,heads,N,headdim)。
- 块表:一个数组,记录序列使用的页面编号。例如,序列 A 的块表可能是 [ 3 , 7 , 12 ] [3, 7, 12] [3,7,12],表示其 KV Cache 存储在页面 3、7 和 12 中。
- 内存池:一个大张量,形状为 ( n u m p a g e s , b a t c h , h e a d s , N , h e a d d i m ) (num_pages, batch, heads, N, head_dim) (numpages,batch,heads,N,headdim),存储所有页面。
2.5 内存访问优化
PagedAttention 依赖高效的内存访问机制:
- 非连续内存访问:通过块表,PagedAttention 可以从非连续的页面中拼接出完整的 K K K 和 V V V。
- CUDA 内核优化:在 GPU 上,PagedAttention 使用定制的 CUDA 内核,将页面拼接和注意力计算融合,减少内存拷贝开销。
3. PagedAttention 的优点
- 减少内存碎片:
- 通过固定大小的页面分配,PagedAttention 避免了动态分配导致的内存碎片化。
- 页面可以在不同序列间复用,提高显存利用率。
- 支持动态序列长度:
- 序列长度可以动态增长,只需分配新页面,无需预先分配连续的大块内存。
- 高效并发支持:
- 在推理服务中,多个请求可以共享页面池,PagedAttention 能高效管理多序列的 KV Cache。
- 支持长序列:
- 通过分页,PagedAttention 支持超长序列(例如 10k+ token),只需增加页面数量即可。
- 内存释放灵活:
- 序列结束后,页面可以立即归还到内存池,供其他序列使用。
4. PagedAttention 的缺点与挑战
- 额外的管理开销:
- 块表和页面池的管理需要额外计算和存储开销,尤其在高并发场景下。
- 内存访问复杂性:
- 非连续的页面访问可能导致缓存命中率降低,增加 GPU 内存访问延迟。
- 需要高度优化的 CUDA 内核来减少性能开销。
- 页面大小选择:
- 页面大小需要仔细调优:
- 页面过小:增加块表开销和内存访问次数。
- 页面过大:可能导致内存浪费(页面未满)。
- 页面大小需要仔细调优:
- 实现复杂性:
- 相比传统 KV Cache,PagedAttention 的实现更复杂,需要修改推理框架的内存管理和注意力计算逻辑。
5. 优化与扩展
PagedAttention 通常与其他优化技术结合使用,以进一步提升性能:
- 页面压缩:
- 对页面中的 K K K 和 V V V 应用量化(如 INT8)或稀疏化,减少内存占用。
- 预分配策略:
- 根据任务特点,预分配一定数量的页面,减少动态分配的开销。
- 异步内存管理:
- 使用异步分配和释放页面,隐藏内存管理延迟。
- 与滑动窗口结合:
- 结合滑动窗口注意力,只保留最近 N N N 个 token 的页面,减少长序列的内存需求。
6. 应用场景
PagedAttention 主要应用于高效推理场景,尤其是:
- 推理服务:如 vLLM、TGI(Text Generation Inference)等推理引擎,用于部署大模型(如 LLaMA、Mistral)。
- 长序列任务:如长文档生成、代码补全、长时间对话。
- 高并发场景:在云服务中,PagedAttention 支持多用户同时请求的高吞吐量推理。
- 资源受限环境:在显存有限的设备上,PagedAttention 能更高效地管理内存。
7. PagedAttention vs 传统 KV Cache
特性 | 传统 KV Cache | PagedAttention |
---|---|---|
内存分配 | 连续分配,动态扩展 | 分页分配,固定大小页面 |
内存碎片 | 容易产生碎片 | 减少碎片,页面复用 |
序列长度支持 | 受显存限制,难以支持超长序列 | 支持动态扩展,适合长序列 |
并发性能 | 多序列管理复杂 | 高效支持多序列,页面共享 |
实现复杂度 | 简单 | 较复杂,需要块表和定制内核 |
内存管理 | 手动释放或重新分配 | 动态分配和释放,内存池管理 |
8. 实现示例(伪代码)
以下是一个简化的 PagedAttention 伪代码(基于 PyTorch):
class PagedAttention:
def __init__(self, num_layers, num_heads, head_dim, page_size, max_pages):
self.page_size = page_size # 每个页面存储的 token 数量
self.memory_pool = torch.zeros((max_pages, num_layers, num_heads, page_size, head_dim)) # 页面池
self.block_tables = {} # 序列ID到页面编号的映射
self.free_pages = list(range(max_pages)) # 可用页面列表
def allocate_page(self, seq_id):
"""为序列分配新页面"""
if not self.free_pages:
raise RuntimeError("No free pages available")
page_idx = self.free_pages.pop(0)
if seq_id not in self.block_tables:
self.block_tables[seq_id] = []
self.block_tables[seq_id].append(page_idx)
return page_idx
def update_cache(self, seq_id, layer, K, V):
"""更新 KV Cache"""
if seq_id not in self.block_tables or not self.block_tables[seq_id]:
page_idx = self.allocate_page(seq_id)
else:
page_idx = self.block_tables[seq_id][-1]
# 检查当前页面是否已满
page_offset = len(self.block_tables[seq_id]) * self.page_size
if page_offset >= self.page_size:
page_idx = self.allocate_page(seq_id)
# 写入 K 和 V 到页面
self.memory_pool[page_idx, layer, :, page_offset % self.page_size, :] = K
self.memory_pool[page_idx, layer, :, page_offset % self.page_size, :] = V
def get_cache(self, seq_id, layer):
"""获取序列的 K 和 V"""
if seq_id not in self.block_tables:
return None, None
pages = self.block_tables[seq_id]
K_pages = [self.memory_pool[p, layer, :, :, :] for p in pages]
V_pages = [self.memory_pool[p, layer, :, :, :] for p in pages]
K = torch.cat(K_pages, dim=2) # 拼接页面
V = torch.cat(V_pages, dim=2)
return K, V
def attention(self, Q, K, V, seq_id, layer):
"""执行注意力计算"""
self.update_cache(seq_id, layer, K, V)
K_cached, V_cached = self.get_cache(seq_id, layer)
scores = torch.matmul(Q, K_cached.transpose(-1, -2)) / sqrt(Q.size(-1))
weights = torch.softmax(scores, dim=-1)
output = torch.matmul(weights, V_cached)
return output
9. 总结
PagedAttention 是一种高效的 KV Cache 管理技术,通过将键和值分页存储,解决了传统 KV Cache 的内存碎片化和动态分配问题。它在大模型推理服务中(如 vLLM)得到广泛应用,特别适合长序列、高并发和资源受限场景。尽管 PagedAttention 增加了实现复杂性和管理开销,但其在内存效率和推理性能上的优势使其成为现代推理引擎的核心组件。