动手学深度学习(四十五)——束搜索

束搜索

  在seq2seq中,我们逐个地预测输出序列的标记,直到预测序列中出现序列结束标记“<eos>”。在本节中,我们将首先对这种 贪心搜索(greedy search)策略进行介绍,并探讨其存在的问题,然后对比这种策略与其他替代策略:穷举搜索(exhaustive search)和束搜索(beam search)。

  在正式介绍贪心搜索之前,让我们使用seq2seq中相同的数学符号定义搜索问题。在任意时间步 t ′ t' t,解码器输出 y t ′ y_{t'} yt 的概率取决于时间步 t ′ t' t 之前的输出子序列 y 1 , … , y t ′ − 1 y_1, \ldots, y_{t'-1} y1,,yt1 和输入序列的信息编码成的上下文变量 c \mathbf{c} c。为了量化计算成本,用 Y \mathcal{Y} Y(它包含“<eos>”)表示输出词汇表。所以这个词汇集合的基数 ∣ Y ∣ \left|\mathcal{Y}\right| Y 就是词汇表的大小。我们还将输出序列的最大标记数指定为 T ′ T' T。因此,我们的目标是从所有 O ( ∣ Y ∣ T ′ ) \mathcal{O}(\left|\mathcal{Y}\right|^{T'}) O(YT) 个可能的输出序列中寻找理想的输出。当然,对于所有输出序列,这些序列中包含的“<eos>”及其之后的部分将在实际输出中丢弃。

一、贪心搜索

  首先,让我们看看一个简单的策略:贪心搜索。该策略已用于seq2seq的序列预测。对于输出序列的任何时间步 t ′ t' t,我们都将基于贪心搜索从 Y \mathcal{Y} Y 中找到具有最高条件概率的标记,即:

y t ′ = argmax ⁡ y ∈ Y P ( y ∣ y 1 , … , y t ′ − 1 , c ) y_{t'} = \operatorname*{argmax}_{y \in \mathcal{Y}} P(y \mid y_1, \ldots, y_{t'-1}, \mathbf{c}) yt=yYargmaxP(yy1,,yt1,c)

一旦输出序列包含了“<eos>”或者达到其最大长度 T ′ T' T,则输出完成。

那么贪心搜索存在什么问题呢?

  实际上,最优序列(optimal sequence)应该是最大化 ∏ t ′ = 1 T ′ P ( y t ′ ∣ y 1 , … , y t ′ − 1 , c ) \prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c}) t=1TP(yty1,,yt1,c) 值的输出序列,这是基于输入序列生成输出序列的条件概率。不幸的是,无法保证通过贪心搜索得到最优序列。

  让我们用一个例子来描述。假设输出中有四个标记“A”、“B”、“C”和“<eos>”。 在上图中,每个时间步下的四个数字分别表示在该时间步生成“A”、“B”、“C”和“<eos>”的条件概率。在每个时间步,贪心搜索选择具有最高条件概率的标记。因此,将在图中中预测输出序列“A”、“B”、“C”和“<eos>”。这个输出序列的条件概率是 0.5 × 0.4 × 0.4 × 0.6 = 0.048 0.5 \times 0.4 \times 0.4 \times 0.6 = 0.048 0.5×0.4×0.4×0.6=0.048

  接下来,让我们看看上图的另一个例子。与图一展示不同,在时间步2中,我们选择图一中的标记“C”,它具有 第二 高的条件概率。由于时间步3所基于的时间步1和2处的输出子序列已从图一中的“A”和“B”改变为“A”和“C”,因此时间步3处的每个标记的条件概率也在图2中改变。假设我们在时间步3选择标记“B”。现在,时间步4以前三个时间步“A”、“C”和“B”的输出子序列为条件,这与图一中的“A”、“B”和“C”不同。因此,在图2中的时间步4生成每个标记的条件概率也不同于图1中的条件概率。结果,图2中的输出序列“A”、“C”、“B”和“<eos>”的条件概率为 0.5 × 0.3 × 0.6 × 0.6 = 0.054 0.5\times0.3 \times0.6\times0.6=0.054 0.5×0.3×0.6×0.6=0.054,这大于图1中的贪心搜索的条件概率。

  在本例中,通过贪心搜索获得的输出序列“A”、“B”、“C”和“<eos>”不是最佳序列。

二、穷举搜索

  如果目标是获得最优序列,我们可以考虑使用 穷举搜索(exhaustive search):穷举地枚举所有可能的输出序列及其条件概率,然后输出条件概率最高的一个。

  虽然我们可以使用穷举搜索来获得最优序列,但其计算量 O ( ∣ Y ∣ T ′ ) \mathcal{O}(\left|\mathcal{Y}\right|^{T'}) O(YT) 可能过高。例如,当 ∣ Y ∣ = 10000 |\mathcal{Y}|=10000 Y=10000 T ′ = 10 T'=10 T=10 时,我们需要评估 1000 0 10 = 1 0 40 10000^{10} = 10^{40} 1000010=1040 序列。这几乎是不可能的。另一方面,贪心搜索的计算量是 O ( ∣ Y ∣ T ′ ) \mathcal{O}(\left|\mathcal{Y}\right|T') O(YT):它通常明显小于穷举搜索。例如,当 ∣ Y ∣ = 10000 |\mathcal{Y}|=10000 Y=10000 T ′ = 10 T'=10 T=10 时,我们只需要评估 10000 × 10 = 1 0 5 10000\times10=10^5 10000×10=105 个序列。

三、束搜索

  决定序列搜索策略取决于一个范围,在任何一个极端情况下都有问题。如果只有准确性最重要?则显然是穷举搜索。如果计算成本最重要?则显然是贪心搜索。实际应用则介于这两个极端之间。

  束搜索(beam search)是贪心搜索的改进版本。它有一个超参数,名为 束宽(beam size) k k k。在时间步 1 1 1,我们选择具有最高条件概率的 k k k 个标记。这 k k k 个标记将分别是 k k k 个候选输出序列的第一个标记。在随后的每个时间步,基于上一时间步的 k k k 个候选输出序列,我们将继续从 k ∣ Y ∣ k\left|\mathcal{Y}\right| kY 个可能的选择中挑出具有最高条件概率的 k k k 个候选输出序列。

  上图演示了束搜索的过程。假设输出的词汇表只包含五个元素: Y = { A , B , C , D , E } \mathcal{Y} = \{A, B, C, D, E\} Y={A,B,C,D,E},其中有一个是“<eos>”。设置束宽为2,输出序列的最大长度为3。在时间步1,假设具有最高条件概率 P ( y 1 ∣ c ) P(y_1 \mid \mathbf{c}) P(y1c)的标记是 A A A C C C。在时间步2,我们计算所有 y 2 ∈ Y y_2 \in \mathcal{Y} y2Y

P ( A , y 2 ∣ c ) = P ( A ∣ c ) P ( y 2 ∣ A , c ) , P ( C , y 2 ∣ c ) = P ( C ∣ c ) P ( y 2 ∣ C , c ) , \begin{aligned}P(A, y_2 \mid \mathbf{c}) = P(A \mid \mathbf{c})P(y_2 \mid A, \mathbf{c}),\\ P(C, y_2 \mid \mathbf{c}) = P(C \mid \mathbf{c})P(y_2 \mid C, \mathbf{c}),\end{aligned} P(A,y2c)=P(Ac)P(y2A,c),P(C,y2c)=P(Cc)P(y2C,c),

从这十个值中选择最大的两个,比如 P ( A , B ∣ c ) P(A, B \mid \mathbf{c}) P(A,Bc) P ( C , E ∣ c ) P(C, E \mid \mathbf{c}) P(C,Ec)。然后在时间步3,对于所有 y 3 ∈ Y y_3 \in \mathcal{Y} y3Y,我们计算:

P ( A , B , y 3 ∣ c ) = P ( A , B ∣ c ) P ( y 3 ∣ A , B , c ) , P ( C , E , y 3 ∣ c ) = P ( C , E ∣ c ) P ( y 3 ∣ C , E , c ) , \begin{aligned}P(A, B, y_3 \mid \mathbf{c}) = P(A, B \mid \mathbf{c})P(y_3 \mid A, B, \mathbf{c}),\\P(C, E, y_3 \mid \mathbf{c}) = P(C, E \mid \mathbf{c})P(y_3 \mid C, E, \mathbf{c}),\end{aligned} P(A,B,y3c)=P(A,Bc)P(y3A,B,c),P(C,E,y3c)=P(C,Ec)P(y3C,E,c),

从这十个值中选择最大的两个,即 P ( A , B , D ∣ c ) P(A, B, D \mid \mathbf{c}) P(A,B,Dc) P ( C , E , D ∣ c ) P(C, E, D \mid \mathbf{c}) P(C,E,Dc)。结果,我们得到六个候选输出序列:(1) A A A;(2) C C C;(3) A , B A,B A,B;(4) C , E C,E C,E;(5) A , B , D A,B,D A,B,D ;(6) C , E , D C,E,D C,E,D

最后,我们基于这六个序列(例如,丢弃包括“<eos>”和之后的部分)获得最终候选输出序列集合。然后我们选择以下得分最高的序列作为输出序列:

1 L α log ⁡ P ( y 1 , … , y L ) = 1 L α ∑ t ′ = 1 L log ⁡ P ( y t ′ ∣ y 1 , … , y t ′ − 1 , c ) , \frac{1}{L^\alpha} \log P(y_1, \ldots, y_{L}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c}), Lα1logP(y1,,yL)=Lα1t=1LlogP(yty1,,yt1,c),

其中 L L L 是最终候选序列的长度, α \alpha α 通常设置为0.75。因为一个较长的序列在上面式子的求和中会有更多的对数项,因此分母中的 L α L^\alpha Lα 用于惩罚长序列。

  束搜索的计算量为 O ( k ∣ Y ∣ T ′ ) \mathcal{O}(k\left|\mathcal{Y}\right|T') O(kYT)。这个结果介于贪心搜索和穷举搜索之间。实际上,贪心搜索可以看作是一种束宽为1的特殊类型的束搜索。通过灵活地选择束宽,束搜索可以在精度和计算成本之间进行权衡。

小结

  • 序列搜索策略包括贪心搜索、穷举搜索和束搜索。
  • 束搜索通过灵活选择束宽,在精度和计算成本之间找到平衡

练习

  1. 我们可以把穷举搜索看作一种特殊的束搜索吗?为什么?

k = 1 是贪心搜索;k = n的时候是穷举搜索

  1. 在seq2seq的机器翻译问题中应用束搜索。束宽如何影响结果和预测速度?

束宽越小,其穷举的次数越小,所以预测速度会更快,但是其结果可能并非最优结果

  1. 在seq2seq中,我们使用语言模型来生成用户提供前缀的文本。它使用了哪种搜索策略?你能改进吗?

Discussions

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

留小星

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值