Transformer推理加速方法-KV缓存(KV Cache)

文章介绍了在Transformer模型的推理过程中,通过缓存KV对(Key-Valuepairs)来减少重复计算,提高效率。在Encoder和Decoder中,分别详细阐述了KV缓存的应用,特别是在Decoder中如何结合先前时刻的缓存信息进行计算。此外,给出了实现缓存功能的伪代码示例,以及使用MultiheadAttention层的注意点。
摘要由CSDN通过智能技术生成

1. 使用KV缓存(KV Cache)

在推理进程中与训练不同,推理进行时上下文输入Encoder后计算出来的 K 和 V K和V KV 是固定不变的,对于这里的 K 和 V K和V KV 可以进行缓存后续复用;在Decoder中推理过程中,同样可以缓存计算出来的 K 和 V K和V KV 减少重复计算,这里注意在输入是am计算时,输入仍需要前面I的输入。

如下图:左边ATTN是Encoder,在T1时刻计算出来对应的 K 和 V K和V KV 并进行缓存,后续推理都不用再计算了;右边ATTN是Decoder,T2时刻通过输入的一个词计算出来 Q T 2 、 K T 2 、 V T 2 Q_{T2}、K_{T2}、V_{T2} QT2KT2VT2,但计算Decoder过程中需要之前时刻T1的所用 K 和 V K和V KV 向量。所以这里Decoder每次计算出来一组新的 K 和 V K和V KV 向量都跟之前向量一起进行缓存,后续也可以重复复用。

在这里插入图片描述

实现的伪码如下:

  • 推理过程中只用取最后一个词做为输入
q = q[-1:]
  • 当前输出只有一个值,在计算输出时把当前的output输出与之前输出cat到一起做为cache
output = torch.cat([cache, output], dim=0)
  • attention的调用如下,每次除了当前时刻的KV值,还加上之前的cache输出
output_t0 = attention(q_t0, k_t0, v_t0)
...
output_t1 = attention(q_t1, k_t1, v_t1, cache = output_t0)
...
output_t2 = attention(q_t2, k_t2, v_t2, cache = output_t1)
... etc
  • attention中的实现如下:
self.attn_head = nn.MultiheadAttention(256, 8)
def attention(q, k ,v, cache=None):
    if cache is not None:
       q = q[-1:]
    out = self.attn_head(q, k, v, attn_mask=triangular_mask)
    if cache is not None:
       out = torch.cat([cache, out], dim=0)
    return out

2. 参考

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MLTalks

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值