束搜索、组束搜索介绍、异同及相关参数

束搜索(Beam Search)与组束搜索(Group Beam Search)介绍

束搜索(Beam Search)

束搜索是一种用于序列生成任务的解码算法,广泛应用于自然语言处理领域,如机器翻译、语音识别等。它的工作原理是:

  • 在每个解码步骤中,不是简单地选择最有可能的下一个词(这被称为贪心搜索),而是考虑前num_beams个最有可能的候选词。
  • 这些候选词形成多个可能的序列路径,这些路径在后续步骤中继续扩展。
  • 最终,当达到最大长度或遇到结束标记时,从所有路径中选择得分最高的序列作为最终输出。

优点:

  • 相比于贪心搜索,它更有可能找到全局最优解。
  • 通过调整束宽(num_beams)可以在计算成本和结果质量之间取得平衡。

缺点:

  • 计算复杂度随着束宽的增加而显著提升。
  • 可能偏向较短的句子,因为较长的句子概率连乘后会变得非常小。
组束搜索(Group Beam Search)

组束搜索是束搜索的一种变体,旨在提高生成结果的多样性。它的主要特点包括:

  • 将总束分为若干个小组(由num_beam_groups指定),每个小组独立进行束搜索。
  • 使用diversity_penalty参数来惩罚来自同一组的重复词汇,从而鼓励不同组之间的差异性。
  • 每个组内的束仍然遵循标准束搜索的规则,但是由于引入了多样性惩罚,不同组的结果更加多样化。

优点:

  • 提高了生成文本的多样性,这对于某些应用场景(如多答案问题回答、创意写作)非常重要。
  • 可以帮助避免模型生成过于相似的输出,特别是在需要探索多种可能性的情况下。

缺点:

  • 需要更多的计算资源,因为它实际上是在并行运行多个束搜索。
  • 结果的质量可能不如单个大束宽的束搜索稳定,取决于如何设置多样性和束组数量。

异同点

特征束搜索 (Beam Search)组束搜索 (Group Beam Search)
目的找到最有可能的单一序列增加生成结果的多样性
束分组单一束组多个独立的束组
多样性惩罚有 (diversity_penalty)
适用场景需要高质量、一致性的输出需要多样化输出

相关参数

  • num_beams: 定义了束的数量,即在每个解码步骤中保留多少个最高概率的候选序列。对于束搜索来说,这是唯一的束宽参数;对于组束搜索,则是指每个组中的束数,并且总束数应该是 num_beams * num_beam_groups
  • num_beam_groups: 仅用于组束搜索,定义了束被分成多少个小组。默认值为1,此时相当于标准束搜索。
  • diversity_penalty: 仅用于组束搜索,用来惩罚来自同一组的重复词汇,以此增加不同组之间的多样性。默认值为0.0,表示没有额外的多样性惩罚。
  • do_sample: 控制是否启用采样模式。对于束搜索及其变体,通常将其设为False,因为它们是确定性的搜索策略。

实际应用建议

  • 如果你希望得到一个高质量、一致性高的输出,可以选择使用标准的束搜索,并适当调整num_beams来平衡计算成本和结果质量。
  • 如果你需要生成多样化的输出,例如在创造性的任务中或者当你不想要重复的答案时,可以尝试使用组束搜索,并根据需要调整num_beam_groupsdiversity_penalty
  • 对于大多数情况下,如果你不确定应该选择哪种方法,先尝试标准束搜索,并根据实验结果决定是否需要引入更多多样性。

下面使用 transformers 库中的 generate方法进行标准束搜索和组束搜索的代码示例。这些例子假设你已经有了一个预训练的模型实例(如 model)和对应的分词器(如tokenizer),并且你有一个输入文本需要生成输出。

标准束搜索示例

这是使用标准束搜索的简单示例,其中我们设置 num_beams 来定义束宽,并确保其他参数不会触发更复杂的搜索策略。

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# 假设已经加载了模型和分词器
model_name = "t5-small"  # 示例模型名称
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 输入文本
input_text = "Translate English to French: Hello, how are you?"
inputs = tokenizer(input_text, return_tensors="pt")

# 使用标准束搜索生成文本
outputs = model.generate(
    **inputs,
    max_length=50,  # 设置最大生成长度
    num_beams=4,    # 设置束宽为4
    early_stopping=True,  # 如果达到结束标记,则提前停止
    no_repeat_ngram_size=2,  # 避免重复的n-grams
    do_sample=False,        # 关闭采样
)

# 解码输出
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

组束搜索示例

对于组束搜索,我们需要额外配置 num_beam_groupsdiversity_penalty 参数以启用多样性惩罚。

# 使用组束搜索生成文本
outputs_grouped = model.generate(
    **inputs,
    max_length=50,          # 设置最大生成长度
    num_beams=8,            # 总束宽
    num_beam_groups=4,      # 分成4个束组
    diversity_penalty=1.0,  # 设置多样性惩罚系数
    early_stopping=True,    # 如果达到结束标记,则提前停止
    do_sample=False,        # 关闭采样
)

# 解码输出
generated_text_grouped = tokenizer.decode(outputs_grouped[0], skip_special_tokens=True)
print(generated_text_grouped)

注意事项

  • 资源消耗:请注意,较大的 num_beams 或更多的 num_beam_groups 会显著增加计算资源的消耗。因此,在生产环境中部署时要考虑性能影响。
  • 参数调整:上面提供的参数值(例如 max_length, num_beams, diversity_penalty 等)仅作为示例。
  • 特殊标记skip_special_tokens=True 在解码时用于忽略特殊的开始和结束标记,使得输出更加干净易读。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值