KV Cache(键-值缓存)是一种在大模型推理中广泛应用的优化技术,其核心思想是利用缓存 key 和 value 来避免重复计算,从而提高推理效率。代价是显存占用会增加。
核心思想
在自注意力层的计算中,对于给定的输入序列,模型会计算每个token的key和value向量。这些向量的值在序列生成过程中是不变的。因此,通过缓存这些向量,可以避免在每次生成新token时重复计算,只需计算新token的query向量,并使用缓存的key/value向量进行自注意力计算 。
具体来说,decoder一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。
在上面的推理过程中,每 step 内,输入一个 token序列,经过Embedding层将输入token序列变为一个三维张量[b, s, h],经过一通计算,最后经logits层将计算结果映射至词表空间,输出张量维度为[b, s, vocab_size]。
当前轮输出token与输入tokens拼接,并作为下一轮的输入tokens,反复多次。可以看出第𝑖+1 轮输入数据只比第𝑖轮输入数据新增了一个token,其他全部相同!因此第𝑖+1轮推理时必然包含了第 𝑖 轮的部分计算。KV Cache的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果。
两阶段推理
推理过程的两个阶段:
-
第一阶段:在第一次迭代时,KV Cache为空,需要为所有输入的token计算key、value和query向量,并将key和value缓存起来。
-
第二阶段:在后续迭代中,只需要为新的token计算key、value和query,然后更新KV Cache
class Attention(nn.Module):` `"""Multi-headed attention from 'Attention Is All You Need' paper"""`` ` `def __init__(self, config: BaiChuanConfig):` `super().__init__()` `self.config = config` `self.hidden_size = config.hidden_size` `self.num_heads = config.num_attention_heads` `self.head_dim = self.hidden_size // self.num_heads` `self.max_position_embeddings = config.max_position_embeddings`` ` `if (self.head_dim * self.num_heads) != self.hidden_size:` `raise ValueError(` ``f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"`` ``f" and `num_heads`: {self.num_heads})."`` `)` `# self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)` `# self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)` `# self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)` `self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)` `self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)` `self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)`` ` `def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):` `return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()`` ` `def forward(` `self,` `hidden_states: torch.Tensor,` `attention_mask: Optional[torch.Tensor] = None,` `position_ids: Optional[torch.LongTensor] = None,` `past_key_value: Optional[Tuple[torch.Tensor]] = None,` `output_attentions: bool = False,` `use_cache: bool = False,` `) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:` `bsz, q_len, _ = hidden_states.size() # 开启kv cache 解码阶段 q_len 为 1` ` proj = self.W_pack(hidden_states)` `proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)` `query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,` `2) # batch_size x source_len x hidden_size` `key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,` `2) # batch_size x target_len x head_size` `value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,` `2) # batch_size x source_len x hidden_size` ` ` `# query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)` `# key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)` `# value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)`` ` `kv_seq_len = key_states.shape[-2]` ` if past_key_value is not None:` `kv_seq_len += past_key_value[0].shape[-2] # 更新 kv_seq_len` `print("kv_seq_len += past_key_value[0].shape[-2]")` `print(kv_seq_len)` `cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)` `query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)` ` # [bsz, nh, t, hd]`` ` `if past_key_value is not None: # 合并` `# reuse k, v, self_attention` `key_states = torch.cat([past_key_value[0], key_states], dim=2)` `value_states = torch.cat([past_key_value[1], value_states], dim=2)` ` ` `past_key_value = (key_states, value_states) if use_cache else None` ` attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)` ` ` `if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):` `raise ValueError(` `f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"` `f" {attn_weights.size()}"` `)`` ` `if attention_mask is not None:` `if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):` `raise ValueError(` `f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"` `)` `attn_weights = attn_weights + attention_mask` `attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))`` ` `# upcast attention to fp32` `attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)` `attn_output = torch.matmul(attn_weights, value_states)`` ` `if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):` `raise ValueError(` ``f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"`` `f" {attn_output.size()}"` `)`` ` `attn_output = attn_output.transpose(1, 2)` `attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)` ` ` `attn_output = self.o_proj(attn_output)`` ` `if not output_attentions:` `attn_weights = None`` ` `return attn_output, attn_weights, past_key_value
显存占用
使用KV Cache的代价是显存占用会增加,因为它需要存储历史全量的KV信息。显存占用的计算公式为:2 * hidden_size * num_layers * decoder_length,而未使用KV Cache时的显存占用为:2 * hidden_size * 1 * decoder_length
如何学习大模型 AI ?
由于新岗位的生产效率,要优于被取代岗位的生产效率,所以实际上整个社会的生产效率是提升的。
但是具体到个人,只能说是:
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
第一阶段(10天):初阶应用
该阶段让大家对大模型 AI有一个最前沿的认识,对大模型 AI 的理解超过 95% 的人,可以在相关讨论时发表高级、不跟风、又接地气的见解,别人只会和 AI 聊天,而你能调教 AI,并能用代码将大模型和业务衔接。
- 大模型 AI 能干什么?
- 大模型是怎样获得「智能」的?
- 用好 AI 的核心心法
- 大模型应用业务架构
- 大模型应用技术架构
- 代码示例:向 GPT-3.5 灌入新知识
- 提示工程的意义和核心思想
- Prompt 典型构成
- 指令调优方法论
- 思维链和思维树
- Prompt 攻击和防范
- …
第二阶段(30天):高阶应用
该阶段我们正式进入大模型 AI 进阶实战学习,学会构造私有知识库,扩展 AI 的能力。快速开发一个完整的基于 agent 对话机器人。掌握功能最强的大模型开发框架,抓住最新的技术进展,适合 Python 和 JavaScript 程序员。
- 为什么要做 RAG
- 搭建一个简单的 ChatPDF
- 检索的基础概念
- 什么是向量表示(Embeddings)
- 向量数据库与向量检索
- 基于向量检索的 RAG
- 搭建 RAG 系统的扩展知识
- 混合检索与 RAG-Fusion 简介
- 向量模型本地部署
- …
第三阶段(30天):模型训练
恭喜你,如果学到这里,你基本可以找到一份大模型 AI相关的工作,自己也能训练 GPT 了!通过微调,训练自己的垂直大模型,能独立训练开源多模态大模型,掌握更多技术方案。
到此为止,大概2个月的时间。你已经成为了一名“AI小子”。那么你还想往下探索吗?
- 为什么要做 RAG
- 什么是模型
- 什么是模型训练
- 求解器 & 损失函数简介
- 小实验2:手写一个简单的神经网络并训练它
- 什么是训练/预训练/微调/轻量化微调
- Transformer结构简介
- 轻量化微调
- 实验数据集的构建
- …
第四阶段(20天):商业闭环
对全球大模型从性能、吞吐量、成本等方面有一定的认知,可以在云端和本地等多种环境下部署大模型,找到适合自己的项目/创业方向,做一名被 AI 武装的产品经理。
- 硬件选型
- 带你了解全球大模型
- 使用国产大模型服务
- 搭建 OpenAI 代理
- 热身:基于阿里云 PAI 部署 Stable Diffusion
- 在本地计算机运行大模型
- 大模型的私有化部署
- 基于 vLLM 部署大模型
- 案例:如何优雅地在阿里云私有部署开源大模型
- 部署一套开源 LLM 项目
- 内容安全
- 互联网信息服务算法备案
- …
学习是一个过程,只要学习就会有挑战。天道酬勤,你越努力,就会成为越优秀的自己。
如果你能在15天内完成所有的任务,那你堪称天才。然而,如果你能完成 60-70% 的内容,你就已经开始具备成为一名大模型 AI 的正确特征了。