LLM的投机采样

推荐这个代码,简单好学,两个小时就能看完

基本

transformers的model模型加载后,执行

output=model(input_ids)

而不是

output=model.chat(**input)
output=model.generate(**input)

你将得到包括logits和past_key_values的dict。
model()不负责更远的预测,你给10个token,它只给你中间层的kv和最后10个对应位置的logits,logits是softmax之前的score。
显然,一次model()可以帮你获得下一个token的输出,即在logits的最后一个位置。

投机采样

基本思路很简单,假如你的input token长度10,small模型往前预测3个,也就是执行model() 3次,注意,你必须手动保存kv,即model()返回的past_key_values,并使用

model(input_ids, past_key_values, use_cache=True)

use_cache保证model使用past_key_values,而不是重新计算。
这一部分在链接的KVCacheModel类里进行kv-cache管理。

然后把small模型的输出token ids给target model做1次model(input_ids),以获得token们的logits。

此时就可以看看采样是否合格了。

检测合格

合格标准是:

  1. target model在位置n的选定token概率比small token的概率大,则一定合格
  2. 反之,假设token(target_model) / token(small_model) =0.5 < 1,则进行一次0~1的均匀采样
r = torch.rand(1)

r>0.5了,就合格。显然,small_model的token越自信,且target_model越不自信,则越不可能选定。

打扫战场

如果全部接收了,此时有个小trick。因为我们必须把投机采样扔进target model计算logits,这本质是在预测下一个token,因此如果全部接收了,我们不希望浪费这个由target model生成的token,所以要做保存处理。

如果部分接收了甚至全部拒绝了,这两种情况可以一起处理,则保存接收的部分,此时仍然可以使用上面提到的trick,从target model产生拒绝的位置采样下一个token,此时这个trick是必须做的,否则如果全部拒绝,你又不使用target的预测,你就永远不会往前移动了。

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值