LLM模型 贪婪、温度、Top-k、核采样方式的区别—附代码与示例
在自然语言生成任务中,不同的采样技术用于从语言模型的输出中选择下一个生成的单词或词语。这些技术包括贪婪采样、温度采样、Top-k采样和核(Nucleus)采样。它们在选择生成单词的过程中有不同的策略,本文将介绍这四种采样方式的区别。
1. 贪婪采样 (Greedy Sampling)
贪婪采样是一种直接选择最可能的下一个词的策略。具体步骤为:
- 从模型输出的logits中,找到概率最大的那个词,直接选择它作为输出。
实现代码
class GreedySampler(Sampler):
def __call__(self, logits: torch.Tensor):
return logits.argmax(dim=-1)
优点
- 简单且计算效率高。
- 保证每一步选择最有可能的结果。
缺点
- 可能会导致生成的文本非常重复和缺乏多样性。
- 贪婪采样只关注当前概率最大的词,忽略了其他潜在的好选择,容易陷入局部最优解。
2. 带温度的采样 (Temperature Sampling)
温度采样通过引入一个温度参数来调整输出概率的分布,以控制生成文本的多样性。温度 T
的作用是平滑或锐化概率分布:
- 当
T = 1
时,采样为标准随机采样。 - 当
T < 1
时,概率分布变得更尖锐,模型更倾向于选择最可能的词。 - 当
T > 1
时,概率分布变得更加平滑,模型会更多地探索低概率的词。
实现代码
class TemperatureSampler(Sampler):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def __call__(self, logits: torch.Tensor):
dist = Categorical(logits=logits / self.temperature)
return dist.sample()
优点
- 提供了生成文本的多样性,尤其是在温度高时。
- 通过调整温度参数,可以控制探索(随机性)与利用(选择高概率词)之间的平衡。
缺点
- 温度的选择需要仔细调节,不同任务或场景下对温度的需求可能不同。
- 温度过低时,生成的文本趋向于贪婪采样;温度过高时,生成的文本可能过于随机。
3. Top-k采样
Top-k采样限制了每次生成时候的候选词数量,模型只会从概率前k个最高的词中进行采样,而忽略其他可能性较小的词。
实现代码
class TopKSampler(Sampler):
def __init__(self, k: int, sampler: Sampler):
self.k = k
self.sampler = sampler
def __call__(self, logits: torch.Tensor):
zeros = logits.new_ones(logits.shape) * float('-inf')
values, indices = torch.topk(logits, self.k, dim=-1)
zeros.scatter_(-1, indices, values)
return self.sampler(zeros)
优点
- 提供了对生成词汇的严格控制,避免生成概率非常低的词。
- 通过限制候选词的数量,避免了一些罕见或不合逻辑的词被选中。
缺点
- 需要设定一个合适的
k
值,如果k
值太小,生成的文本可能会缺乏多样性;如果k
值太大,则效果与标准采样相似。
4. 核采样 (Nucleus Sampling)
核采样是一种自适应的采样方法,它选择的候选词集合 V(p)
是满足累计概率和大于或等于给定阈值 p
的最小词汇子集。与Top-k采样不同,核采样的候选词数量不是固定的,而是基于累计概率动态确定的。
示例
假设同样的语境:“今天的天气很”,但这次我们将会有不同的词汇及其概率分布,我们也会使用不同的阈值 ( p ) 来展示如何动态确定选词数量。
模型预测的词汇概率
- 好:0.4
- 冷:0.3
- 热:0.2
- 潮湿:0.05
- 多变:0.03
- 干燥:0.02
排序与累积概率
按概率从高到低排序并计算累积概率:
- 好:0.4
- 冷:0.7 (0.4 + 0.3)
- 热:0.9 (0.7 + 0.2)
- 潮湿:0.95 (0.9 + 0.05)
- 多变:0.98 (0.95 + 0.03)
- 干燥:1.00 (0.98 + 0.02)
确定核集合
这次,我们将选择不同的阈值 ( p ) 来观察核集合如何变化:
- 当 ( p = 0.7 ):
- 核集合包括:“好”和“冷”,因为它们的累积概率首次超过 0.7。
- 当 ( p = 0.9 ):
- 核集合扩展到:“好”,“冷”,和“热”,因为它们的累积概率首次超过 0.9。
- 当 ( p = 0.95 ):
- 核集合进一步扩展到:“好”,“冷”,“热”和“潮湿”,因为这是累积概率首次超过 0.95。
抽样
在每种情况下,我们从对应的核集合中随机选取一个词作为下一个词。选择的范围和多样性取决于 ( p ) 值的大小,而词的数量是根据这个阈值动态确定的,不是固定的。
实现代码
class NucleusSampler(Sampler):
"""
## Nucleus 采样器
Nucleus 采样器根据给定的概率 p 选择词汇的一个子集,并从中进行采样。
"""
def __init__(self, p: float, sampler: Sampler):
"""
### 初始化
:param p: 要选择的令牌概率之和,即 p 值。
:param sampler: 用于从选定令牌中进行采样的采样器。
"""
# 保存 p 值
self.p = p
# 保存采样器
self.sampler = sampler
# 初始化 softmax 层,用于将 logits 转换为概率
self.softmax = nn.Softmax(dim=-1)
def __call__(self, logits: torch.Tensor):
"""
### 从 logits 中进行 Nucleus 采样
:param logits: 输入的 logits 张量,形状为 (batch_size, num_tokens)。
:return: 采样得到的令牌索引,形状为 (batch_size,)。
"""
# 获取概率 P(x_i | x_1:i-1)
probs = self.softmax(logits)
# 按降序对概率进行排序,并获取排序后的索引
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
# 按排序顺序获取概率的累积总和
cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找出累积总和小于 p 的令牌
nucleus = cum_sum_probs < self.p
# 在前面加一个 True,这样我们可以在累积概率小于 p 的最小令牌数量之后添加一个令牌
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
# 获取对数概率并掩盖非核部分
sorted_log_probs = torch.log(sorted_probs)
sorted_log_probs[~nucleus] = float('-inf')
# 使用采样器从排序后的对数概率中进行采样
sampled_sorted_indexes = self.sampler(sorted_log_probs)
# 获取实际的索引
res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
# 返回采样得到的令牌索引
return res.squeeze(-1)
优点
- 灵活性强,自动调整候选词集合,避免了固定的词数限制。
- 在生成文本时能够更好地平衡多样性与高概率词的利用,表现优于Top-k采样。
缺点
- 参数
p
的选择需要调节,不同任务可能需要不同的p
值。 - 计算复杂度较高,尤其是当处理较大的词汇表时。
总结
采样方法 | 优点 | 缺点 |
---|---|---|
贪婪采样 | 简单、高效,始终选择最有可能的词 | 文本生成可能单一,缺乏多样性 |
温度采样 | 通过调整温度控制多样性,适应性强 | 温度的调节需要谨慎,过高或过低的温度可能产生不理想的结果 |
Top-k采样 | 控制候选词数量,避免选择低概率词 | k 值选择需要调节,k 太小可能导致文本单一 |
核采样 | 动态选择候选词集合,更灵活,生成文本质量较高 | 参数 p 需要调节,计算复杂度较高 |
每种采样方式都有其适用的场景,根据具体的应用和对生成文本的要求,可以选择不同的采样策略。