如何计算kv cache的缓存大小

符号定义

首先,定义一些符号:

( B ):批大小(Batch Size)
( L ):序列长度(Sequence Length)
( N ):Transformer 层数(Number of Transformer Layers)
( H ):注意力头数(Number of Attention Heads)
( D ):每个注意力头的维度(Dimension per Head),即 ( D = Hidden Size / H D = \text{Hidden Size} / H D=Hidden Size/H)
( S ):数据类型大小(Size of Data Type),以字节为单位。例如:
FP32(32位浮点数):( S = 4 ) 字节
FP16(16位浮点数):( S = 2 ) 字节

KV 缓存的内存计算

对于每一层的多头注意力机制,我们需要存储 键(Key)值(Value) 的缓存。对于每一层,键和值的缓存大小计算如下:

键缓存(Key Cache)大小:

Size Key = B × H × L × D × S \text{Size}_{\text{Key}} = B \times H \times L \times D \times S SizeKey=B×H×L×D×S

值缓存(Value Cache)大小:

Size Value = B × H × L × D × S \text{Size}_{\text{Value}} = B \times H \times L \times D \times S SizeValue=B×H×L×D×S

因此,每一层的 KV 缓存总大小为:

SizeKV per layer = SizeKey + Size Value = 2 × B × H × L × D × S \text{Size}{\text{KV per layer}} = \text{Size}{\text{Key}} + \text{Size}_{\text{Value}} = 2 \times B \times H \times L \times D \times S SizeKV per layer=SizeKey+SizeValue=2×B×H×L×D×S

由于模型有 ( N ) 层,因此 总的 KV 缓存大小为:

Total SizeKV = N × SizeKV per layer = 2 × B × H × L × D × N × S \text{Total Size}{\text{KV}} = N \times \text{Size}{\text{KV per layer}} = 2 \times B \times H \times L \times D \times N \times S Total SizeKV=N×SizeKV per layer=2×B×H×L×D×N×S

具体示例计算

假设以下参数:

批大小:( B = 1 )

序列长度:( L = 1 ) (即 token 数为 1)

层数:( N = 12 ) (例如,一个小型的 Transformer)

隐藏层尺寸:( Hidden Size = 768 \text{Hidden Size} = 768 Hidden Size=768 )

注意力头数:( H = 12 )

每个头的维度:

D = Hidden Size H = 768 12 = 64 D = \frac{\text{Hidden Size}}{H} = \frac{768}{12} = 64 D=HHidden Size=12768=64

数据类型:FP32,( S = 4 ) 字节

现在,我们计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S   = 2 × 1 × 12 × 1 × 64 × 4   = 2 × 1 × 12 × 1 × 64 × 4   = 2 × 12 × 64 × 4   = 2 × 12 × 64 × 4   = 6144  字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 6144\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×12×1×64×4 =2×1×12×1×64×4 =2×12×64×4 =2×12×64×4 =6144 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer  = 12 × 6144   = 73728  字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 12 \times 6144 \ &= 73728\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =12×6144 =73728 字节

即大约 72 KB。

虽然这个数字看起来不大,但在大型模型中,参数会显著增大。例如,考虑一个具有以下参数的大型模型:

层数:( N = 96 )

隐藏层尺寸:( Hidden Size = 12288 \text{Hidden Size} = 12288 Hidden Size=12288)

注意力头数:( H = 96 )

每个头的维度:

D = 12288 96 = 128 D = \frac{12288}{96} = 128 D=9612288=128

数据类型:FP16,( S = 2 ) 字节

计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S   = 2 × 1 × 96 × 1 × 128 × 2   = 2 × 96 × 128 × 2   = 2 × 96 × 128 × 2   = 49 , 152  字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 96 \times 1 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 49,152\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×96×1×128×2 =2×96×128×2 =2×96×128×2 =49,152 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer  = 96 × 49 , 152   = 4 , 719 , 616  字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 96 \times 49,152 \ &= 4,719,616\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =96×49,152 =4,719,616 字节

即大约 4.5 MB。

注意事项

模型规模的影响: 可以看到,随着 层数 ( N )、隐藏层尺寸 和 注意力头数 ( H ) 的增加,KV 缓存的内存需求会显著增长。

序列长度的影响: 虽然在 ( L = 1 ) 时,序列长度对内存影响较小,但在生成长序列时,( L ) 会增加,导致 KV 缓存内存占用线性增长。

数据类型的影响: 使用 FP16 可以将内存占用减少一半,但对于大型模型,内存需求仍然很高。

总结

即使 token 数为 1,由于模型的层数、注意力头数、每个头的维度等参数较大,KV 缓存仍然需要消耗较大的内存。

通过以上公式,可以直观地看到各个参数对 KV 缓存内存占用的影响,从而理解为什么在处理单个 token 时仍需要大的内存。

优化建议

减少模型规模: 降低 ( N )、( H ) 或 ( D ) 的值。
使用半精度: 采用 FP16 或更低精度的数据类型。
批量大小优化: 确保 ( B ) 仅为需要的最小值。
序列长度控制: 在可能的情况下,限制生成序列的最大长度 ( L )。

### Transformer 模型中键值 (KV) 缓存的实现 在 Transformer 模型中,KV 缓存用于存储先前计算得到的 Key 和 Value 向量,在解码阶段可以重复利用这些向量来减少不必要的重新计算。这种机制显著降低了推理过程中的计算开销并提高了效率。 对于 KV 缓存的具体实现方式如下: 当处理序列数据时,每次仅需更新最新时间步对应的 Query 向量,并从缓存读取之前所有的时间步所对应 Keys 和 Values 进行注意力权重计算[^1]。 ```python def forward(self, query, key=None, value=None, cache=None): if cache is not None and 'key' in cache and 'value' in cache: # 使用已有的kv缓存 key = torch.cat([cache['key'], new_key], dim=1) value = torch.cat([cache['value'], new_value], dim=1) attn_output_weights = self.attention(query=query, key=key, value=value) if cache is not None: cache.update({'key': key, 'value': value}) return attn_output_weights ``` 通过上述代码片段展示了如何在一个典型的多头自注意模块内集成 KV 缓存逻辑。这里假设 `new_key` 和 `new_value` 是当前输入产生的新 K/V 对;而如果存在有效的缓存,则会先将其与新的 K/V 结合再执行后续操作[^3]。 ### 优化方法 为了进一步提升性能,还可以采取一些额外措施来进行优化: - **分片技术**: 将大型张量分割成更小的部分分别保存到不同设备上,从而有效缓解单个 GPU 的显存压力。 - **压缩策略**: 应用量化或其他形式的数据表示简化手段降低所需占用的空间大小。 - **异构硬件支持**: 利用 CPU/GPU 协同工作模式加速特定部分运算的同时保持较低功耗水平[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值