大模型推理--temperature、top_k、top_p理解

LLM推理的最后一步是要从众多候选token中选择一个输出,一般可以选择softmax概率最大的token输出,这样相同的输入都会获得确定的输出。不过,在很多情况下,最优输出不见得是最好的输出,尤其在当下LLM还不完美的情况下。为此我们需要让LLM的输出在保证靠谱的前提下尽可能多样,temperature、top_k、top_p这三个变量就是出于此目的设计出来的。当然很多博客中已经介绍了这三个变量的作用,但是很多人可能对细节还不了解,正好最近看了一个Python实现,借此给大家详细介绍一下这三个变量的作用。

1. 源码实现

我参考的源码是Freeze Omni这个项目中的post_process,并进行了简化,源码如下:

def do_sampling(logits: torch.Tensor, temperature=1.0, top_k=20, top_p=0.8):
    if temperature != 1.0:
        logits = logits / temperature

    probs = F.softmax(logits, dim=-1)

    if top_k != 0:
        top_k_probs, top_k_indices = torch.topk(probs, top_k)
        probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
        probs = probs / probs.sum()

    if top_p > 0.0:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p #前0后1的数组
        sorted_indices_to_remove[0] = 0 #确保要保留一个

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        probs[indices_to_remove] = 0
        probs = probs / probs.sum()

    token_index = torch.multinomial(probs, 1)
return token_index

2. temperature作用

从源码中可以看到,涉及temperature的部分很简单,只有一句: logits = logits / temperature。虽然只有这简单的一句,但是内涵相当丰富。Temperature可以用来调节输入logits不同值之间的差值(比例关系不变),这个在后续求softmax的时候会有很大影响。可以设想一下,假设我们将temperature设的很大,经过这个除法之后logits的值就都非常小,在求softmax的时候各个位置的概率就会很均匀,在后续利用top_k和top_p进行采样时就会让采样具有多样性,结果也会更富有变化;如果temperature设的很小,则会导致logits更容易出现一个很大的值,在求softmax的时候就会导致某个token的概率很大,进而就会导致后续的采样结果更确定。

3. top_k的作用

top_k的作用就是从softmax的输出中选择前k大值,如果top_k等于0相当于softmax的所有输出均参与后面的采样。由于softmax的输出分布极为不均匀,往往只有一个或者几个较大的值,其他的值都接近于0。如果不设置top_k,一些概率很小的token也会参与采样,可能会导致结果过于发散(当然这种情况的概率也很低),所以一般都会设置top_k。
代码在求得top_k的概率和位置之后,将这k个值散布到和原始probs相同大小的零tensor中再归一化,这相当于把top_k之外的所有位置都置零。概率为0的位置在后续采样的时候就不会被选中。

4. top_p的作用

top_p相关的代码较长,作用解释起来稍微有点复杂。它是从累积概率的角度对softmax或者top_k之后的token位置进行进一步的筛选。代码首先会对probs进行降序排序,然后计算累积概率,找到累积概率首次超过top_p的位置,截断此后的所有概率。后续几步运算就是把累积概率超过top_p的所有位置置零,确保这些位置在采样时被排除。这样,top_p能在保持多样性的同时,避免极端小概率事件的影响,使结果更可控。通过合理设置top_k和top_p,能在精确性和多样性间找到平衡。例如,当top_p设为0.9时,意味着只保留累积概率达到90%的前几个token,其余的则被舍弃。这样既保证了采样结果的丰富性,又避免了低概率token的干扰,使得生成文本在可控范围内更具质量和连贯性。通过细致调整这两个参数,模型输出将更加符合预期,满足不同场景下的需求。
代码的最后利用multinomial来进行采样。在经过top_k和top_p的筛选之后,multinomial的输入是只有几个位置概率大于0,其他位置均为0的一个概率分布。
multinomial函数会根据调整后的概率分布进行随机采样,选择最可能的token,确保生成文本既符合预期又具备一定随机性,从而提升整体的自然性和可读性。

5. 性能优化

上述代码对理解这三个参数的含义比较好,但是在性能方面却存在不少问题,我们尝试对其进行优化。

5.1 优化一

原始代码在求得top_k之后会重新构造一个和原始logits一样大小的tensor,然后再进行排序。但实际上排完序之后的tensor前k个元素和top_k的结果一样,完全没必要构造新的tensor,我们可以直接利用top_k的结果求累积概率,这样我们就把sort的时间给省掉了。
利用top_k_probs对top_p进行筛选,找到首次超过top_p的位置进行截断归一化,相比之前对完整的probs进行筛选现在只需要在top_k个位置上进行筛选,速度提升了不少。按照该种优化方法实现的采样代码如下:

def do_sampling(logits: torch.Tensor, temperature=1.0, top_k=20, top_p=0.8):
    if temperature != 1.0:
        logits = logits / temperature

    probs = F.softmax(logits, dim=-1)

    if top_k != 0:
        top_k_probs, top_k_indices = torch.topk(probs, top_k)
		top_k_probs /= top_k_probs.sum()  

    if top_p > 0.0:
        cumulative_probs = torch.cumsum(top_k_probs, dim=-1)
        mask = cumulative_probs > top_p 
		mask[0] = 1 #确保要保留至少一个位置
		Probs = top_k_probs[mask]
        probs = probs / probs.sum()

    token_index = torch.multinomial(probs, 1)
	return top_k_indices[token_index]

上述代码在虽然优化了性能,但是也将top_k和top_p绑定在一起,大家酌情使用。

5.2 优化二

优化一相当于原始代码的等价变换,相同的输入得到相同的输出。还有一个略微改变输出结果的方法,主要是在softmax身上做文章。原始代码是对完整的logits求softmax,然后再求topk。可以改为先对logits求topk,再对筛选后的topk结果求softmax,这样top_k的复杂度没变,但是softmax的复杂度则大幅降低。这样计算可行的原因是利用了softmax的单调性。不过这样计算会导致softmax的输出与原始代码不一样,需要我们重新调整top_k和top_p的取值,以确保结果的可靠性。代码与优化一类似,在此不再给出代码。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值