用Transformers库实现基础的大模型文本生成以及KV cache注意事项

根据输入的prompt,生成一段指定长度的文字。Llama跑起来太慢了,这里用GPT-2作为列子。

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

prompt_text = "This is a nice story that makes me"
max_gen_len = 9
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')

print('>>> Way 1: Use `model.generate()` to generate tokens with KV cache')
generated_ids = model.generate(input_ids, max_length=prompt_len+max_gen_len, pad_token_id=tokenizer.eos_token_id)
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('generated_text:', generated_text)

print('>>> Way 2: Use `for loop` to generate tokens with KV cache')
past_key_values = None
print('Prefill Stage..')
outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
print('Decoding/Generating Stage..')
for _ in range(max_gen_len - 1):
    outputs = model(input_ids=pred_token_idx, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values  # if use_cache=False, past_key_values=None
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids.append(pred_token_idx.item())
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(torch.Tensor(generated_ids), skip_special_tokens=True)
print('generated_text:', prompt_text + generated_text)

这里提供了两种方法实现文本生成:

  • model.generate():给模型输入prompt,一次性得到所有输出的token,最方便的写法
  • for loop:这是StreamingLLM中给的代码例子,也揭示了自回归生成的原理。首先是prefill阶段,输入prompt,得到KV cache和生成的第一个token;然后是decoding/generating阶段,开始自回归生成token,每次生成的模型输入是当前新token和KV cache,每生成一个token都会自动更新KV cache
  • 当然也可以不用KV cache,那么每次的input/query就会越来越大,开销非常大(后面我有时间再把代码贴过来)

最终,可以看到两种方法生成的文本是一模一样的:

在这里插入图片描述

进一步探究自回归过程中维度的变化:

在这里插入图片描述

这个就是标准的自回归生成任务了,不管是GPT还是Llama,都是如此(至少PyTorch版本都是这样的,Flax版本的KV cache有点奇怪,用的lax.dynamic_update_slice(cached_key.value, key, indices),KV cache的维度并没有随着token的生成而增加…不太明白)。

2024.5.29补充:Flax版本的KV cache是用空间换时间,简单来说就是,在一开始就先初始化构造了足够大的KV cache(全0填充),多大呢?其实就是prompt_len+gen_len这么大。然后在decoding阶段更新KV cache时,不像Torch那样用concat更新,而是用add的方式将新的K和V向量加进去。这个就是Flax版本和Torch版本的本质区别了,他们达成的效果一致。但是如果是要做KV cache压缩优化,还是建议用Torch的concat版本。

官方代码(我定位到了KV cache更新的那一行):

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.zwX

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值