Beam Search的学习笔记(附代码实现)

引言

Beam Search 是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的基本思想,然后再介绍一些beam search的优化方法,最后附上自己的代码实现。

1. Beam Search的基础版本

在生成文本的时候,通常需要进行解码操作,贪心搜索 (Greedy Search) 是比较简单的解码。Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。示意图如下所示:
在这里插入图片描述
假设我们生成词表中有三个单词{我,爱,你}。我们设 K = 2 K=2 K=2。那么我们在第一时刻确定两个候选输出是{我,你}。紧接着我们要考虑第二个输出,具体步骤如下:

  • 确定单词“我”为第一时刻输出,并将其作为第二时刻输入,在已知 p ( x , 我 ) p(x,我) p(x)的情况下,各个单词的输出概率为3种情况,每个组合的概率为 P ( 我 ∣ x ) P ( y 2 ∣ x , 我 ) P(我|x)P(y_2|x,我) P(x)P(y2x,)
  • 同样我们把“你”也作为第二时刻输入,同样也有三种组合。
  • 最后我们在六种组合中选择概率最大的三个组合。

接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:在这里插入图片描述

以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。

2. Beam Search的优化

Beam search算法也有许多改进的地方。

2.1 Length normalization:惩罚短句

根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数,这样比较的就是每个单词的平均概率大小。此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。谷歌给的公式如下:
在这里插入图片描述
其中α∈[0,1],谷歌建议取值为[0.6,0.7]之间,α用于length normalization。

2.2 Coverage normalization:惩罚重复

另外我们在序列到序列任务中经常会发现一个问题,2016 年, 华为诺亚方舟实验室的论文提到,机器翻译的时候会存在over translation or undertranslation due to attention coverage。 作者提出coverage-based atttention机制来解决coverage 问题。 Google machine system 利用了如下的方式进行了length normalization 和 coverage penalty。

还是上述公式,β用于控制coverage penalty
在这里插入图片描述

coverage penalty 主要用于使用 Attention 的场合,通过 coverage penalty 可以让 Decoder 均匀地关注于输入序列 x x x 的每一个 token,防止一些 token 获得过多的 Attention

2.3 End of sentence normalization:抑制长句

有的时候我们发现生成的序列一直生成下去不会停止,有的时候我们可以显式的设置最大生成长度进行控制,这里我们可以采用下式来进行约束:
在这里插入图片描述
其中 ∣ X ∣ |X| X是source的长度, ∣ Y ∣ |Y| Y是当前target的长度,那么由上式可知,target长度越长的话,上述得分越低,这样就会防止出现生成一直不停止的情况。


3. Beam Search的代码实现

总的来说,beam search不保证全局最优,但是比greedy search搜索空间更大,一般结果比greedy search要好。下面附上一些代码实现:

首先,首先定义一个 Beam 类,作为一个存放候选序列的容器,属性需维护当前序列中的 token 以及对应的对数概率,同时还需维护跟当前 timestep 的 Decoder 相关的一些变量。此外,还需要给 Beam 类实现两个函数:一个 extend 函数用以扩展当前的序列(即添加新的 time step的 token 及相关变量);一个 score 函数用来计算当前序列的分数(在Beam类下的seq_score函数中有Length normalization以及Coverage normalization)。

class Beam(object):
    def __init__(self,
                 tokens,
                 log_probs,
                 decoder_states,
                 coverage_vector):
        self.tokens = tokens
        self.log_probs = log_probs
        self.decoder_states = decoder_states
        self.coverage_vector = coverage_vector
    def extend(self,
               token,
               log_prob,
               decoder_states,
               coverage_vector):
        return Beam(tokens=self.tokens + [token],
                    log_probs=self.log_probs + [log_prob],
                    decoder_states=decoder_states,
                    coverage_vector=coverage_vector)

    def seq_score(self):
        """
        This function calculate the score of the current sequence.
        """
        len_Y = len(self.tokens)
        # Lenth normalization
        ln = (5+len_Y)**config.alpha / (5+1)**config.alpha
        cn = config.beta * torch.sum(  # Coverage normalization
            torch.log(
                config.eps +
                torch.where(
                    self.coverage_vector < 1.0,
                    self.coverage_vector,
                    torch.ones((1, self.coverage_vector.shape[1])).to(torch.device(config.DEVICE))
                )
            )
        )

        score = sum(self.log_probs) / ln + cn
        return score

    def __lt__(self, other):
        return self.seq_score() < other.seq_score()

    def __le__(self, other):
        return self.seq_score() <= other.seq_score()

接着我们需要实现一个 best_k 函数,作用是将一个 Beam 容器中当前 time step 的变量传入 Decoder 中,计算出新一轮的词表概率分布,并从中选出概率最大的 k 个 token 来扩展当前序列(其中加入了End of sentence normalization),得到 k 个新的候选序列。

    def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):
        """Get best k tokens to extend the current sequence at the current time step.
        """
        # use decoder to generate vocab distribution for the next token
        x_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)
        x_t = x_t.to(self.DEVICE)

        # Get context vector from attention network.
        context_vector, attention_weights, coverage_vector = \
            self.model.attention(beam.decoder_states,
                                 encoder_output,
                                 x_padding_masks,
                                 beam.coverage_vector)

        # Replace the indexes of OOV words with the index of OOV token
        # to prevent index-out-of-bound error in the decoder.

        p_vocab, decoder_states, p_gen = \
            self.model.decoder(replace_oovs(x_t, self.vocab),
                               beam.decoder_states,
                               context_vector)

        final_dist = self.model.get_final_distribution(x,
                                                       p_gen,
                                                       p_vocab,
                                                       attention_weights,
                                                       torch.max(len_oovs))
        # Calculate log probabilities.
        log_probs = torch.log(final_dist.squeeze())
        # Filter forbidden tokens.
        # EOS token penalty. Follow the definition in
        # https://opennmt.net/OpenNMT/translation/beam_search/.
        log_probs[self.vocab.EOS] *= \
            config.gamma * x.size()[1] / len(beam.tokens)
        log_probs[self.vocab.UNK] = -float('inf')
        # Get top k tokens and the corresponding logprob.
        topk_probs, topk_idx = torch.topk(log_probs, k)

        # Extend the current hypo with top k tokens, resulting k new hypos.
        best_k = [beam.extend(x,
                  log_probs[x],
                  decoder_states,
                  coverage_vector) for x in topk_idx.tolist()]

        return best_k

最后我们实现主函数 beam_search。初始化encoder、attention和decoder的输⼊,然后对于每⼀个decodestep,对于现有的k个beam,我们分别利⽤best_k函数来得到各⾃最佳的k个extended beam,也就是每个decode step我们会得到k*k个新的beam,然后只保留分数最⾼的k个,作为下⼀轮需要扩展的k个beam。为了只保留分数最⾼的k个beam,我们可以⽤⼀个堆(heap)来实现,堆的中只保存k个节点,根结点保存分数最低的beam。

    def beam_search(self,
                    x,
                    max_sum_len,
                    beam_width,
                    len_oovs,
                    x_padding_masks):
        """Using beam search to generate summary.
        """
        # run body_sequence input through encoder
        encoder_output, encoder_states = self.model.encoder(
            replace_oovs(x, self.vocab))
        coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)
        # initialize decoder states with encoder forward states
        decoder_states = self.model.reduce_state(encoder_states)

        # initialize the hypothesis with a class Beam instance.

        init_beam = Beam([self.vocab.SOS],
                         [0],
                         decoder_states,
                         coverage_vector)

        # get the beam size and create a list for stroing current candidates
        # and a list for completed hypothesis
        k = beam_width
        curr, completed = [init_beam], []

        # use beam search for max_sum_len (maximum length) steps
        for _ in range(max_sum_len):
            # get k best hypothesis when adding a new token

            topk = []
            for beam in curr:
                # When an EOS token is generated, add the hypo to the completed
                # list and decrease beam size.
                if beam.tokens[-1] == self.vocab.EOS:
                    completed.append(beam)
                    k -= 1
                    continue
                for can in self.best_k(beam,
                                       k,
                                       encoder_output,
                                       x_padding_masks,
                                       x,
                                       torch.max(len_oovs)
                                       ):
                    # Using topk as a heap to keep track of top k candidates.
                    # Using the sequence scores of the hypos to campare
                    # and object ids to break ties.
                    add2heap(topk, (can.seq_score(), id(can), can), k)

            curr = [items[2] for items in topk]
            # stop when there are enough completed hypothesis
            if len(completed) == beam_width:
                break
        # When there are not engouh completed hypotheses,
        # take whatever when have in current best k as the final candidates.
        completed += curr
        # sort the hypothesis by normalized probability and choose the best one
        result = sorted(completed,
                        key=lambda x: x.seq_score(),
                        reverse=True)[0].tokens
        return result
Beam search(束搜索)是一种常用的序列生成算法,它在生成一个序列时,每一步都会保留 $k$ 个最有可能的候选项,然后选择得分最高的其中一个作为下一步的输入,并重复此过程直到生成完整个序列。 以下是一个 Python 实现的简单的 beam search 算法代码示例: ```python import math def beam_search(start_sequence, model, beam_size, max_len): """ 使用 beam search 算法生成序列。 参数: start_sequence:起始序列,一般为特殊起始符号。 model:生成模型。 beam_size:束宽,即保留的最有可能的候选项数量。 max_len:生成序列的最大长度。 返回: 生成的序列。 """ # 将起始序列作为第一步的输入。 candidates = [{'tokens': [start_sequence], 'score': 0}] # 不断扩展序列,直到最大长度。 for _ in range(max_len): # 生成下一步的所有可能的候选项。 next_candidates = [] for candidate in candidates: # 获取当前序列的最后一个标记。 last_token = candidate['tokens'][-1] # 使用模型预测下一个标记和对应的得分。 scores = model(last_token) # 选择得分最高的前 beam_size 个候选项。 top_scores, top_tokens = torch.topk(scores, beam_size) for score, token in zip(top_scores, top_tokens): # 计算候选项的总得分。 total_score = candidate['score'] + score.item() # 将新的候选项加入到候选项列表中。 next_candidates.append({'tokens': candidate['tokens'] + [token], 'score': total_score}) # 保留得分最高的前 beam_size 个候选项。 sorted_candidates = sorted(next_candidates, key=lambda x: -x['score'])[:beam_size] candidates = sorted_candidates # 返回得分最高的序列。 best_sequence = candidates[0]['tokens'] return best_sequence ``` 这个代码示例中,输入参数包括起始标记、生成模型、束宽和最大序列长度。在算法的主循环中,首先将起始标记作为第一步的输入,然后不断扩展序列直到达到最大长度为止。在每一步中,使用模型预测下一个标记的概率分布,并选择得分最高的前 $k$ 个候选项作为下一步的输入。然后计算每个候选项的总得分,并保留得分最高的前 $k$ 个候选项。最后返回得分最高的序列。 需要注意的是,这个示例代码中的模型预测函数 `model` 需要根据具体的应用场景进行实现。此外,这个算法还有一些改进的空间,比如可以使用剪枝等技巧来提高算法的效率和准确性。
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值