基本
transformers的model模型加载后,执行
output=model(input_ids)
而不是
output=model.chat(**input)
output=model.generate(**input)
你将得到包括logits和past_key_values的dict。
model()不负责更远的预测,你给10个token,它只给你中间层的kv和最后10个对应位置的logits,logits是softmax之前的score。
显然,一次model()可以帮你获得下一个token的输出,即在logits的最后一个位置。
投机采样
基本思路很简单,假如你的input token长度10,small模型往前预测3个,也就是执行model() 3次,注意,你必须手动保存kv,即model()返回的past_key_values,并使用
model(input_ids, past_key_values, use_cache=True)
use_cache保证model使用past_key_values,而不是重新计算。
这一部分在链接的KVCacheModel类里进行kv-cache管理。
然后把small模型的输出token ids给target model做1次model(input_ids),以获得token们的logits。
此时就可以看看采样是否合格了。
检测合格
合格标准是:
- target model在位置n的选定token概率比small token的概率大,则一定合格
- 反之,假设token(target_model) / token(small_model) =0.5 < 1,则进行一次0~1的均匀采样
r = torch.rand(1)
r>0.5了,就合格。显然,small_model的token越自信,且target_model越不自信,则越不可能选定。
打扫战场
如果全部接收了,此时有个小trick。因为我们必须把投机采样扔进target model计算logits,这本质是在预测下一个token,因此如果全部接收了,我们不希望浪费这个由target model生成的token,所以要做保存处理。
如果部分接收了甚至全部拒绝了,这两种情况可以一起处理,则保存接收的部分,此时仍然可以使用上面提到的trick,从target model产生拒绝的位置采样下一个token,此时这个trick是必须做的,否则如果全部拒绝,你又不使用target的预测,你就永远不会往前移动了。