大模型推理优化之 KV Cache

文章介绍了KVCache在语言模型推理中的应用,通过缓存重复数据加速解码过程,提升了GPT-4等模型的推理速度。同时,分析了KVCache的显存占用,讨论了MultiQueryAttention和GroupedQueryAttention两种优化策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原文:https://zhuanlan.zhihu.com/p/677660376

目录

收起

KV Cache 定义

KV Cache 原理

KV Cache 实现

KV Cache 显存占用分析

KV Cache 优化方法

在语言模型推理的过程中,性能优化一直是一个备受关注的话题。LLM(Large Language Models)的出现使得自然语言处理取得了显著的进展,但随之而来的是庞大的模型和复杂的计算过程,因此推理效率的提升变得至关重要。在这个背景下,KV Cache(键-值缓存)成为了一项被广泛应用的推理优化技术。

KV Cache 定义

KV Cache,即键-值缓存,是一种用于存储键值对数据的缓存机制。在语言模型的推理过程中,经常需要多次访问相同的数据,而KV Cache通过将这些数据缓存到内存中,提供了快速的数据访问速度,从而加速推理过程。该技术仅应用于解码阶段。如 decode only 模型(如 GPT3、Llama 等)、encode-decode 模型(如 T5)的 decode 阶段,像 Bert 等非生成式模型并不适用。

KV Cache 原理

推理过程:给定一个问题,模型会输出一个回答。生成回答的过程每次只生成一个 token,输出的 token会和问题拼接在一起,再次作为输入传给模型,这样不断重复直至生成终止符停止。

GPT-4推理过程图

下图是Scaled dot-product attention 有无 KV Cache 优化计算过程的比较。一般情况下,在每个生成步骤中,都会重新计算之前token的注意力,而实际上我们只想计算新 token 的注意力。而采用 KV Cache 方法后,会把之前 Token的 KV 值存下来,新 token 预测时只需要从缓存中读取结果就可以了。

动图封面

Scaled dot-product attention 有无 KV Cache 比较,图片来源:https://medium.com/@joaolages/kv-caching-explained-276520203249

KV Cache 实现

huggingface的 transformer库已经实现了 KV cache,在推理时新增了past_key_values,设置 use_cache=True 或 config.use_cache=True 就可以了。

past_key_values( Cacheor tuple(tuple(torch.FloatTensor)), optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the past_key_valuesreturned by the model at a previous stage of decoding, when use_cache=Trueor config.use_cache=True.
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "Llama-2-7b-chat-hf"

device = "cuda:7" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
for use_cache in (True, False):
  times = []
  for _ in range(10):  # measuring 10 generations
      start = time.time()
      input = tokenizer("What is KV caching?", return_tensors="pt").to(device)
      outputs = model.generate(**input, use_cache=use_cache, max_new_tokens=1000, temperature=0.00001)
      times.append(time.time() - start)
  print(f"{'With' if use_cache else 'Without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

执行结果如下所示:

With KV caching: 8.946 +- 0.011 seconds
Without KV caching: 58.68 +- 0.012 seconds

从结果可以看出使用 KV Cache 方法进行大模型推理,推理速度增加了6.56倍,差异巨大。

KV Cache 显存占用分析

假设输入的序列长度是 𝑚,输出序列长度是 𝑛 , 𝑏 为数据批次大小, 𝑙 为层数, ℎ 为隐向量维度,以 FP16(2bytes) 来保存,那么 KV Cache的峰值显存占用大小为 𝑏(𝑚+𝑛)ℎ∗𝑙∗2∗2=4𝑏𝑙ℎ(𝑚+𝑛) ,第一个 2 代表 K、V,第二个 2 代表 2bytes。可见随着批次大小和长度的增加,KV Cache 的显存占用也会快速增大。

KV Cache 优化方法

这里主要介绍下 Multi Query Attention 和 Grouped Query Attention。

Multi Query Attention

Multi-query attention is identical except that the different heads share a single set of keys and values.

MQA 和 MHA的区别是:每个头共享相同的 K、V 权重而不共享Q的权重。

Grouped Query Attention

Grouped-query attention divides query heads intoG groups, each of which shares a single key head and value head.

分组注意力将查询头分为 G 组,每组共享一个键头和值头。GQA-G 是指有 G 组的分组查询。GQA-1,有一个组,因此有一个键头和值头,等同于 MQA,而 GQA-H,组数等于头数,等同于 MHA。

MHA、MQA、GQA比较可参考下图。

Multi Head Attention、Multi Query Attention、Grouped Query Attention 比较

使用MHA、MQA、GQA进行KV Cache 显存占用情况比较

MHA: 𝑏(𝑚+𝑛)ℎ∗𝑙∗2∗2=4𝑏𝑙ℎ(𝑚+𝑛) ;

MQA: 4𝑏𝑙ℎ(𝑚+𝑛)/𝐻 , 𝐻 代表头数;

GQA: 4𝑏𝑙ℎ(𝑚+𝑛)∗𝐺/𝐻 , 𝐻 代表头数, 𝐺 代表分组数;

MQA、GQA Huggingface 库都有实现,具体见llm_tutorial_optimization

如果觉得本文对您有帮助,麻烦点个小小的赞,谢谢大家啦~

Transformers KV Caching Explained
Huggingface-llama2
Fast Transformer Decoding: One Write-Head is All You Need
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

编辑于 2024-01-16 21:52・IP 属地江苏

### KV缓存的工作原理 KV缓存是一种基于键值对存储机制的高效数据访问方式。当应用程序请求特定的数据时,会通过唯一的键来查找相应的值[^2]。 具体来说: - **Get操作**:客户端发送带有指定键的获取请求给缓存服务器;如果该键存在于缓存中,则立即返回对应的值;否则执行未命中处理逻辑。 - **Put操作**:用于向缓存中插入新的键值对或者更新已有的条目。这通常发生在首次加载某个资源之后,以便后续对该资源的快速检索。 - **Evict操作**:为了保持有限内存空间的有效利用,系统可能会主动移除一些不常用或过期的数据项。这一过程可以通过多种淘汰算法实现,比如LRU(Least Recently Used),即最近最少使用的优先被淘汰。 - **TTL(Time To Live)**:为每一条记录设定生存时间,在这段时间过后自动清除该项,防止陈旧信息长期占用资源的同时也确保了数据的新鲜度。 ```python import time from collections import OrderedDict class SimpleKVCahce: def __init__(self, capacity=10): self.cache = OrderedDict() self.capacity = capacity def get(self, key): if key not in self.cache: return None value = self.cache.pop(key) self.cache[key] = value # Move to end (most recently used) return value def put(self, key, value): if key in self.cache: del self.cache[key] elif len(self.cache) >= self.capacity: oldest_key = next(iter(self.cache)) del self.cache[oldest_key] self.cache[key] = value def evict_least_recently_used_item(self): try: least_recently_used_key = next(iter(self.cache)) del self.cache[least_recently_used_key] except StopIteration as e: pass cache = SimpleKVCahce(capacity=3) # Example usage of the cache operations print(cache.get('a')) # Output: None cache.put('a', 'apple') cache.put('b', 'banana') print(cache.get('a')) # Output: apple time.sleep(5) # Simulate TTL expiration after some period ``` ### KV缓存的优势 引入KV缓存可以显著提升系统的整体性能和响应速度。主要体现在以下几个方面: - **减少延迟**:由于大多数情况下可以直接从高速缓存而非低速磁盘或其他外部源读取所需的信息,从而大大缩短了等待时间和提高了用户体验质量[^1]。 - **减轻负载压力**:对于频繁访问但变化较少的内容,将其保存于靠近应用层的位置能够有效缓解数据库等后端组件面临的高并发查询带来的巨大负担。 - **优化成本效益**:合理配置大小适中的缓存区域能够在不影响准确性的前提下节省大量的硬件投资以及运维开支[^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值