【大模型理论篇】Transformer KV Cache原理深入浅出

1. 前言

         Transformer KV Cache(键值缓存)是一种应用于仅解码器Transformer中的工程优化技术,通过增加内存利用率来减少计算量。在讨论KV Cache之前,我想起之前与行业内的大佬交流时提到,大模型时代的到来要求我们放下既有的思维模式,才能更好地接受并理解大模型的理念和实践。对此我部分赞同。从宏观角度看,大模型的表现确实是颠覆性的,难以用传统的认知框架来解释。我曾在《压缩泛化-对大语言模型智能涌现的理解》一文中探讨过,大模型正在做的“压缩即智能”这件事,是相当伟大的创新。然而,从技术方案和技术细节的角度来看,大模型依然在沿用大量已有的方法,并不是凭空产生的。因此,仍然可以从原有的知识体系出发,迁移学习大模型的相关知识。这两种观点并不矛盾,只是视角和立场不同,从整体上看是统一的,因此不必对大模型产生畏难情绪或认知偏差。

        其实,从事过隐私计算相关算法实践的人会理解,隐私计算同样需要用新的认知方式去审视,不能完全套用传统的数据安全保护方法,否则难以抓住其核心。然而,当我们从算法实现的细节来看待隐私计算时,会发现它其实将现有的机器学习、分布式计算、密码学等知识在隐私计算的层面实现了统一。

        回到本文讨论的KV Cache,实话实说,其技术实现和原理都相对简单,有一定大模型工程优化经验的人一般都能想到这种方法。因此,在学习和理解大模型时,保持平和的心态至关重要。

2. KV Cache算法原理介绍

2.1 知识背景

        首先回顾GPT中Transformer的结构,可以回看《GPT系列预训练模型原理讲解》

  • Transformer的输入是一个序列的tokens(或这些tokens的批处理)。
  • 在GPT Transformer中,每一块都有一个注意力层和一个前馈层。
  • Transformer中的几乎每一层/操作都是基于每个token的,除了注意力层(仅其中的一部分)。
  • 注意力层有多个头,通常情况下,d_model = d_head * n_heads。

在生成推理阶段:每次前向传递都会生成一个token,将其附加到输入中,然后将输入(现在的序列长度+1)再次传递回模型进行下一次前向传递。然而,这种朴素的方法显然非常低效,因为:

  • 它重新计算了先前的键、值和注意力行。
  • 网络中基于每个token的其他部分也会再次被使用,浪费了计算资源。

2.2 自注意力层计算量的二次增长问题分析

        首先分析在Transformer模型中的多头注意力(MHA)层的处理过程,假设只处理一个长度为t的序列(即batchsize大小为1):

  • 在整个过程中,输入序列中的每个token都由一个稠密向量表示。
  • 注意力层的输入是一系列稠密向量,每个输入token对应一个,由前一解码器块生成。
  • 对于每个输入向量,注意力层生成一个相同维度的输出稠密向量。

        考虑单个注意力头的情况:

  • 首先,使用三个不同的投影为每个输入向量生成三个低维稠密向量:查询向量(query)、键向量(key)和值向量(value)。因此会有t个查询向量、t个键向量和t个值向量。
  • 对于每个查询向量,生成的输出向量等于所有值向量的线性组合,线性组合的系数为注意力得分。对于每个查询,对应的输出向量是这些值向量的注意力加权平均。注意力得分是通过查询向量与每个键向量的点积计算得到的。通过这种方式,为序列中的每个token生成了一个包含其他tokens信息的上下文表示。
  • 在自回归解码的情况下,不能使用所有可能的值向量来为特定查询生成输出表示。实际上,在计算与特定token相关的查询输出时,不能使用序列中后续token的值向量。这种限制通过一种Mask掩码的方式来实现,该技术本质上将后续token的注意力得分设置为零。
  • 最后,每个注意力头的输出被连接起来,并通过最后的线性变换得到最终输出。

注意力计算的二次增长

        计算注意力得分所需的浮点运算数(FLOPs)。对于给定的注意力头,对于一个批量大小为batch且总长度为t的序列(包括提示和生成的完成序列),注意力得分矩阵通过一个形状为(t, d_head)的查询张量与一个转置的形状为(d_head, t)的键张量相乘而生成。

        单次矩阵乘法需要多少FLOPs?形状为(m,n)的矩阵与另一个形状为(n,p)的矩阵乘法大约涉及2mnp次操作,这里加法次数需要注意一下。

M^{m, n} \times N^{n, p} \Rightarrow 2mnp       

        在本例子中,单头单序列的注意力得分计算大约需要2 * d_{head} * t^2次FLOPs。总体而言,注意力得分计算所需的FLOPs为2 \times batch \times n_{layers} \times n_{head} \times d_{head} \times t^2=2 \times batch \times n_{layers} \times d_{model} \times t^2。显然,计算量随t的二次增长显而易见。

2.3 KV缓存是什么

        接下来我们引出本文的焦点:KV Cache。 

        先来看一下,当一个新token被添加到输入x中时:

  • 在q、k、v中各增加一行【4】。

  • attention矩阵(att)的尺寸从(T, T)变为(T+1, T+1),即新增了一行和一列。
  • 新增的这一行表示新token与所有之前token之间的注意力(attention)。
  • 由于att矩阵被掩码为下三角矩阵,新增的这一列除了最后一行的值之外,其余位置都是零。

  • att矩阵中的新增行将导致在v中产生一个新的输出,这个输出是att矩阵的最后一行与v的所有列相乘得到的结果,即att[-1, :] @ v

        大家是否发现了规律,当新添加一个token后,查询矩阵、键矩阵和值矩阵只是多了一行,而之前的行不受影响。注意力矩阵也只是多了一行,因此,输出也只有一行额外的行。由于自注意力掩码的作用,所有先前的行都不受影响。

        【2】中给出了可视化的例子:

        以下是将最初的两个token传递给单个注意力头后,查询(Query)、键(Key)、值(Value)、注意力矩阵(Attention Matrix)和输出值(Output Values)的样子:(这里的头维度 d_head = 4)。

        现在再添加一个token后:

        即使再添加一个新的token,attention weights的前面几行也都不受影响。

因此:
        现在可以理解 KV 缓存是如何工作的:对于生成的每个新token,不需要传入整个序列,因此可以避免重新计算整个注意力矩阵。只需要以下面的方式对新token进行操作:

  1. 仅为新token计算新的 q、k、v 行。
  2. 新的 q 行将立即被使用。(这也解释了为什么没有查询缓存的原因)
  3. 将新的键、值条目附加到现有的 K、V 缓存中。
  4. 通过新的 q 行和 k_cache 的转置进行矩阵向量乘法来计算新的注意力行。
  5. 通过新的注意力行和 v_cache 的转置进行矩阵向量乘法来计算新的 v 行。
  6. 输出(仅针对最新标记)被传递到下一层。

        这是一种通过增加内存使用来节省重复计算的权衡。

2.4 KV Cache的内存占用分析

        既然KV Cache时通过增加内存来降低重复计算的量,那么有必要分析一下内存占用大小,其实后续针对KV Cache又提出了很多节省内存占用的方法,比如【5,6,7】,有兴趣可以看下优化缓存的分析。本文先主要关注内存占用的分析。

        假设模型的参数配置信息如下:

  • Transformer 中有 n_layers 个层块。
  • 每个层块中有一个多头注意力层。
  • 每个多头注意力层有 n_heads 个注意力头,每个头的 kv 的尺寸为 d_head
  • 需要为 KV 都缓存一份。
  • 最大上下文长度为 n_context
  • 精度为 n_bytes,例如对于 FP32 是 4。
  • 推理时的批量大小为 batch_size

        那么总的内存大小:

kv_cache_size = n_layers * n_heads * 2 * n_context * d_head * n_bytes * batch_size

        简化后:

kv_cache_size = 2 * n_bytes * n_layers * d_model * n_context * batch_size

例如,针对 OPT-30B 模型的内存量级计算【4】:

  • n_bytes = 2(FP16)
  • n_layers = 48
  • d_model = 7168
  • n_context = 1024
  • batch = 128

计算结果为 180,388,626,432 字节,约为 180 GB。

2.5 KV缓存实现线性注意力增长

        那么注意力模块的计算量随着缓存的引入发生了什么变化?

        转置后的键张量仍然是形状(t, d_head),但查询张量现在的形状是(d_head, 1)。因此,单头单序列的注意力得分计算现在需要2 \times d_{head} \times t 次FLOPs,而总体的注意力计算需要2 \times batch \times n_{layers} \times d_{model} \times t次FLOPs。注意力计算现在随序列总长度线性增长。

2.6 结论

        注意力得分的计算量随着序列总长度呈二次增长。由于注意力计算中的掩码机制,在每次生成步骤中,实际上可以避免为过去的token重新计算键和值向量。每次计算新的键和值向量时,可以将它们缓存到GPU内存中,以便在后续的迭代中重复使用,从而避免重新计算。引入这种优化策略后,注意力机制的FLOPs随总序列长度实现了线性增长。

3. 参考材料

【1】LLM Inference Series: 3. KV caching explained

【2】What is the Transformer KV Cache?

【3】Transformers KV Caching Explained

【4】The KV Cache: Memory Usage in Transformers

【5】LLM profiling guides KV cache optimization

【6】vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention

【7】PagedAttention

### Transformer KVCache INT8 Quantization Implementation and Optimization In the context of optimizing transformer models, particularly focusing on the key-value (KV) cache mechanism within these architectures, applying INT8 quantization can significantly enhance computational efficiency while maintaining model accuracy. The transformation to lower precision arithmetic is especially beneficial for operations where high numerical resolution is less critical. When implementing INT8 quantization specifically for the KV Cache in a transformer architecture: - **Quantization Strategy**: For optimal performance enhancement, it's crucial that layers which do not commute well with quantize/dequantize (Q/DQ) operations are targeted for conversion into INT8 format[^1]. This means ensuring both inputs and outputs remain as INT8 throughout such processes. - **Fusion Techniques**: To further optimize this process, fusion techniques should be employed so that after fusing certain operations like addition (`Add`), their input and output types become strictly INT8. Such an approach ensures maximum performance gains by minimizing floating-point computations during inference phases. Additionally, when considering broader aspects beyond just the KV Cache component: - **Comprehensive Model Compression Methods**: Beyond simple layer-wise or operation-specific optimizations, comprehensive strategies encompassing various forms of compression—such as pruning, low-rank factorization, knowledge distillation alongside weight sharing—are essential components in achieving efficient deployment without sacrificing too much predictive power[^2]. To illustrate how one might implement these principles programmatically using Python code tailored towards TensorFlow Lite framework—which supports post-training dynamic range quantization—we provide below a simplified example demonstrating basic steps involved in converting parts of a pre-trained transformer model suitable for hardware acceleration through INT8 quantized weights and activations: ```python import tensorflow as tf converter = tf.lite.TFLiteConverter.from_saved_model('path_to_transformer') converter.optimizations = [tf.lite.Optimize.DEFAULT] def representative_dataset_gen(): for _ in range(100): yield [np.random.uniform(-1., 1., size=(batch_size, seq_length)).astype(np.float32)] converter.representative_dataset = representative_dataset_gen converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.inference_input_type = tf.int8 converter.inference_output_type = tf.int8 tflite_quant_model = converter.convert() with open('transformer_int8.tflite', 'wb') as f: f.write(tflite_quant_model) ``` This script demonstrates setting up a TFLite converter configured for default optimization levels but explicitly targeting INT8 support for both inputs and outputs. A generator function provides sample data points necessary for calibrating the quantizer parameters accurately before exporting the final optimized version ready for deployment.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

源泉的小广场

感谢大佬的支持和鼓励!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值