全面了解Beam Search 1

最近研究了一下用基于BERT的encoder-decoder结构做文本生成任务,碰巧管老师昨天的文章也介绍了以生成任务见长的GPT模型,于是决定用两篇文章大家介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索)。

解码及贪心搜索

生成式任务相比普通的分类、tagging等NLP任务会复杂不少。在生成的时候,模型的输出是一个时间步一个时间步依次获得的,而且前面时间步的结果还会影响后面时间步的结果。也就是说,每一个时间步,模型给出的都是基于历史生成结果的条件概率。为了生成完整的句子,需要一个称为解码的额外动作来融合模型多个时间步的输出,而且使得最终得到的序列的每一步条件概率连乘起来最大。

在文本生成任务中,每一个时间步可能的输出种类称为字典大小(vocabulary size,我们用 v v v表示),进行T步随机的生成可能获得的结果总共有 v T v^T vT种。拿中文文本生成来说, v v v的值大约是5000-6000,即常用汉字的个数。在如此大的基数下,遍历整个生成空间是不现实的。

最容易想到的策略是贪心搜索,即每一个时间步都取出一个条件概率最大的输出,再将从开始到当前步的结果作为输入去获得下一个时间步的输出,直到模型给出生成结束的标志。例如下图,每一个时间步都取出了条件概率最大一个结果,生成了序列[A,B,C]

贪心搜索示意图
贪心搜索示意图

很明显,这样做将原来指数级别的求解空间直接压缩到了与长度线性相关的大小。由于丢弃了绝大多数的可能解,这种关注当下的策略无法保证最终得到的序列概率是最优的。

Beam Search

而beam search是对贪心策略一个改进。思路也很简单,就是稍微放宽一些考察的范围。在每一个时间步,不再只保留当前分数最高的1个输出,而是保留num_beams个。当num_beams=1时集束搜索就退化成了贪心搜索。

下图是一个实际的例子,每个时间步有ABCDE共5种可能的输出,即 v = 5 v=5 v=5,图中的num_beams=2,也就是说每个时间步都会保留到当前步为止条件概率最优的2个序列。

在这里插入图片描述
Beam Search示意图

  • 在第一个时间步,A和C是最优的两个,因此得到了两个结果[A],[C],其他三个就被抛弃了;
  • 第二步会基于这两个结果继续进行生成,在A这个分支可以得到5个候选人,[AA],[AB],[AC],[AD],[AE],C也同理得到5个,此时会对这10个进行统一排名,再保留最优的两个,即图中的[AB][CE]
  • 第三步同理,也会从新的10个候选人里再保留最好的两个,最后得到了[ABD],[CED]两个结果。

可以发现,beam search在每一步需要考察的候选人数量是贪心搜索的num_beams倍,因此是一种牺牲时间换性能的方法。

以上就是Beam Search的基本概念,下面我们解析一种高效率实现方式。

Beam Search代码解析

Beam Search的原理虽然简单,但实际实现的时候却有很多细节要考虑。下面要解析这个实现出自于NLP界著名Python包Transformers,我为了说明方便做了一些改动。

一个正确且高效的算法需要处理的问题大概有两个:

  • 充分利用硬件,可以处理批量数据,且尽量使用并行计算少用循环
  • 处理好长短不同的生成结果

下面是基础版的beam search函数定义。其中context是编码器编码获得的向量,batch_size是每批数据中包含的样本量,bos_token_id是句子开头标志的token id,pad_token_id是用于填充的token id,eos_token_id是句子结束标志的token id。这里给参数填上的默认值和我们后面讲解时使用的例子是一致的。

def beam_search_generate(context,
                        batch_size=3,
                        max_length=20,
                        min_length=2,
                        num_beams=2,
                        bos_token_id=101,
                        pad_token_id=0,
                        eos_token_id=102,
                        ):
    pass

在函数中主要执行以下三个步骤:

  • 准备初始输入
  • 在当前生成的序列长度未达到max_length时扩展生成序列
  • 准备最终输出的序列

下面我们分别解析。

准备初始输入

# 建立beam容器,每个样本一个
generated_hyps = [
    BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
    for _ in range(batch_size)
]

# 每个beam容器的得分,共batch_size*num_beams个
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=encoder_input_ids.device)
beam_scores = beam_scores.view(-1)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值