束搜索(Beam Search)与组束搜索(Group Beam Search)介绍
束搜索(Beam Search)
束搜索是一种用于序列生成任务的解码算法,广泛应用于自然语言处理领域,如机器翻译、语音识别等。它的工作原理是:
- 在每个解码步骤中,不是简单地选择最有可能的下一个词(这被称为贪心搜索),而是考虑前
num_beams
个最有可能的候选词。 - 这些候选词形成多个可能的序列路径,这些路径在后续步骤中继续扩展。
- 最终,当达到最大长度或遇到结束标记时,从所有路径中选择得分最高的序列作为最终输出。
优点:
- 相比于贪心搜索,它更有可能找到全局最优解。
- 通过调整束宽(
num_beams
)可以在计算成本和结果质量之间取得平衡。
缺点:
- 计算复杂度随着束宽的增加而显著提升。
- 可能偏向较短的句子,因为较长的句子概率连乘后会变得非常小。
组束搜索(Group Beam Search)
组束搜索是束搜索的一种变体,旨在提高生成结果的多样性。它的主要特点包括:
- 将总束分为若干个小组(由
num_beam_groups
指定),每个小组独立进行束搜索。 - 使用
diversity_penalty
参数来惩罚来自同一组的重复词汇,从而鼓励不同组之间的差异性。 - 每个组内的束仍然遵循标准束搜索的规则,但是由于引入了多样性惩罚,不同组的结果更加多样化。
优点:
- 提高了生成文本的多样性,这对于某些应用场景(如多答案问题回答、创意写作)非常重要。
- 可以帮助避免模型生成过于相似的输出,特别是在需要探索多种可能性的情况下。
缺点:
- 需要更多的计算资源,因为它实际上是在并行运行多个束搜索。
- 结果的质量可能不如单个大束宽的束搜索稳定,取决于如何设置多样性和束组数量。
异同点
特征 | 束搜索 (Beam Search) | 组束搜索 (Group Beam Search) |
---|---|---|
目的 | 找到最有可能的单一序列 | 增加生成结果的多样性 |
束分组 | 单一束组 | 多个独立的束组 |
多样性惩罚 | 无 | 有 (diversity_penalty ) |
适用场景 | 需要高质量、一致性的输出 | 需要多样化输出 |
相关参数
num_beams
: 定义了束的数量,即在每个解码步骤中保留多少个最高概率的候选序列。对于束搜索来说,这是唯一的束宽参数;对于组束搜索,则是指每个组中的束数,并且总束数应该是num_beams * num_beam_groups
。num_beam_groups
: 仅用于组束搜索,定义了束被分成多少个小组。默认值为1,此时相当于标准束搜索。diversity_penalty
: 仅用于组束搜索,用来惩罚来自同一组的重复词汇,以此增加不同组之间的多样性。默认值为0.0,表示没有额外的多样性惩罚。do_sample
: 控制是否启用采样模式。对于束搜索及其变体,通常将其设为False
,因为它们是确定性的搜索策略。
实际应用建议
- 如果你希望得到一个高质量、一致性高的输出,可以选择使用标准的束搜索,并适当调整
num_beams
来平衡计算成本和结果质量。 - 如果你需要生成多样化的输出,例如在创造性的任务中或者当你不想要重复的答案时,可以尝试使用组束搜索,并根据需要调整
num_beam_groups
和diversity_penalty
。 - 对于大多数情况下,如果你不确定应该选择哪种方法,先尝试标准束搜索,并根据实验结果决定是否需要引入更多多样性。
下面使用
transformers
库中的generate
方法进行标准束搜索和组束搜索的代码示例。这些例子假设你已经有了一个预训练的模型实例(如model
)和对应的分词器(如tokenizer
),并且你有一个输入文本需要生成输出。
标准束搜索示例
这是使用标准束搜索的简单示例,其中我们设置 num_beams
来定义束宽,并确保其他参数不会触发更复杂的搜索策略。
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# 假设已经加载了模型和分词器
model_name = "t5-small" # 示例模型名称
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 输入文本
input_text = "Translate English to French: Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt")
# 使用标准束搜索生成文本
outputs = model.generate(
**inputs,
max_length=50, # 设置最大生成长度
num_beams=4, # 设置束宽为4
early_stopping=True, # 如果达到结束标记,则提前停止
no_repeat_ngram_size=2, # 避免重复的n-grams
do_sample=False, # 关闭采样
)
# 解码输出
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
组束搜索示例
对于组束搜索,我们需要额外配置 num_beam_groups
和 diversity_penalty
参数以启用多样性惩罚。
# 使用组束搜索生成文本
outputs_grouped = model.generate(
**inputs,
max_length=50, # 设置最大生成长度
num_beams=8, # 总束宽
num_beam_groups=4, # 分成4个束组
diversity_penalty=1.0, # 设置多样性惩罚系数
early_stopping=True, # 如果达到结束标记,则提前停止
do_sample=False, # 关闭采样
)
# 解码输出
generated_text_grouped = tokenizer.decode(outputs_grouped[0], skip_special_tokens=True)
print(generated_text_grouped)
注意事项
- 资源消耗:请注意,较大的
num_beams
或更多的num_beam_groups
会显著增加计算资源的消耗。因此,在生产环境中部署时要考虑性能影响。 - 参数调整:上面提供的参数值(例如
max_length
,num_beams
,diversity_penalty
等)仅作为示例。 - 特殊标记:
skip_special_tokens=True
在解码时用于忽略特殊的开始和结束标记,使得输出更加干净易读。