引言
Beam Search 是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的基本思想,然后再介绍一些beam search的优化方法,最后附上自己的代码实现。
1. Beam Search的基础版本
在生成文本的时候,通常需要进行解码操作,贪心搜索 (Greedy Search) 是比较简单的解码。Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。示意图如下所示:
假设我们生成词表中有三个单词{我,爱,你}。我们设
K
=
2
K=2
K=2。那么我们在第一时刻确定两个候选输出是{我,你}。紧接着我们要考虑第二个输出,具体步骤如下:
- 确定单词“我”为第一时刻输出,并将其作为第二时刻输入,在已知 p ( x , 我 ) p(x,我) p(x,我)的情况下,各个单词的输出概率为3种情况,每个组合的概率为 P ( 我 ∣ x ) P ( y 2 ∣ x , 我 ) P(我|x)P(y_2|x,我) P(我∣x)P(y2∣x,我)。
- 同样我们把“你”也作为第二时刻输入,同样也有三种组合。
- 最后我们在六种组合中选择概率最大的三个组合。
接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:
以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。
2. Beam Search的优化
Beam search算法也有许多改进的地方。
2.1 Length normalization:惩罚短句
根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数,这样比较的就是每个单词的平均概率大小。此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。谷歌给的公式如下:
其中α∈[0,1],谷歌建议取值为[0.6,0.7]之间,α用于length normalization。
2.2 Coverage normalization:惩罚重复
另外我们在序列到序列任务中经常会发现一个问题,2016 年, 华为诺亚方舟实验室的论文提到,机器翻译的时候会存在over translation or undertranslation due to attention coverage。 作者提出coverage-based atttention机制来解决coverage 问题。 Google machine system 利用了如下的方式进行了length normalization 和 coverage penalty。
还是上述公式,β用于控制coverage penalty
coverage penalty 主要用于使用 Attention 的场合,通过 coverage penalty 可以让 Decoder 均匀地关注于输入序列 x x x 的每一个 token,防止一些 token 获得过多的 Attention。
2.3 End of sentence normalization:抑制长句
有的时候我们发现生成的序列一直生成下去不会停止,有的时候我们可以显式的设置最大生成长度进行控制,这里我们可以采用下式来进行约束:
其中
∣
X
∣
|X|
∣X∣是source的长度,
∣
Y
∣
|Y|
∣Y∣是当前target的长度,那么由上式可知,target长度越长的话,上述得分越低,这样就会防止出现生成一直不停止的情况。
3. Beam Search的代码实现
总的来说,beam search不保证全局最优,但是比greedy search搜索空间更大,一般结果比greedy search要好。下面附上一些代码实现:
首先,首先定义一个 Beam
类,作为一个存放候选序列的容器,属性需维护当前序列中的 token 以及对应的对数概率,同时还需维护跟当前 timestep 的 Decoder 相关的一些变量。此外,还需要给 Beam 类实现两个函数:一个 extend
函数用以扩展当前的序列(即添加新的 time step的 token 及相关变量);一个 score
函数用来计算当前序列的分数(在Beam类下的seq_score函数中有Length normalization以及Coverage normalization)。
class Beam(object):
def __init__(self,
tokens,
log_probs,
decoder_states,
coverage_vector):
self.tokens = tokens
self.log_probs = log_probs
self.decoder_states = decoder_states
self.coverage_vector = coverage_vector
def extend(self,
token,
log_prob,
decoder_states,
coverage_vector):
return Beam(tokens=self.tokens + [token],
log_probs=self.log_probs + [log_prob],
decoder_states=decoder_states,
coverage_vector=coverage_vector)
def seq_score(self):
"""
This function calculate the score of the current sequence.
"""
len_Y = len(self.tokens)
# Lenth normalization
ln = (5+len_Y)**config.alpha / (5+1)**config.alpha
cn = config.beta * torch.sum( # Coverage normalization
torch.log(
config.eps +
torch.where(
self.coverage_vector < 1.0,
self.coverage_vector,
torch.ones((1, self.coverage_vector.shape[1])).to(torch.device(config.DEVICE))
)
)
)
score = sum(self.log_probs) / ln + cn
return score
def __lt__(self, other):
return self.seq_score() < other.seq_score()
def __le__(self, other):
return self.seq_score() <= other.seq_score()
接着我们需要实现一个 best_k
函数,作用是将一个 Beam 容器中当前 time step 的变量传入 Decoder 中,计算出新一轮的词表概率分布,并从中选出概率最大的 k 个 token 来扩展当前序列(其中加入了End of sentence normalization),得到 k 个新的候选序列。
def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):
"""Get best k tokens to extend the current sequence at the current time step.
"""
# use decoder to generate vocab distribution for the next token
x_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)
x_t = x_t.to(self.DEVICE)
# Get context vector from attention network.
context_vector, attention_weights, coverage_vector = \
self.model.attention(beam.decoder_states,
encoder_output,
x_padding_masks,
beam.coverage_vector)
# Replace the indexes of OOV words with the index of OOV token
# to prevent index-out-of-bound error in the decoder.
p_vocab, decoder_states, p_gen = \
self.model.decoder(replace_oovs(x_t, self.vocab),
beam.decoder_states,
context_vector)
final_dist = self.model.get_final_distribution(x,
p_gen,
p_vocab,
attention_weights,
torch.max(len_oovs))
# Calculate log probabilities.
log_probs = torch.log(final_dist.squeeze())
# Filter forbidden tokens.
# EOS token penalty. Follow the definition in
# https://opennmt.net/OpenNMT/translation/beam_search/.
log_probs[self.vocab.EOS] *= \
config.gamma * x.size()[1] / len(beam.tokens)
log_probs[self.vocab.UNK] = -float('inf')
# Get top k tokens and the corresponding logprob.
topk_probs, topk_idx = torch.topk(log_probs, k)
# Extend the current hypo with top k tokens, resulting k new hypos.
best_k = [beam.extend(x,
log_probs[x],
decoder_states,
coverage_vector) for x in topk_idx.tolist()]
return best_k
最后我们实现主函数 beam_search
。初始化encoder、attention和decoder的输⼊,然后对于每⼀个decodestep,对于现有的k个beam,我们分别利⽤best_k
函数来得到各⾃最佳的k个extended beam,也就是每个decode step我们会得到k*k个新的beam,然后只保留分数最⾼的k个,作为下⼀轮需要扩展的k个beam。为了只保留分数最⾼的k个beam,我们可以⽤⼀个堆(heap)来实现,堆的中只保存k个节点,根结点保存分数最低的beam。
def beam_search(self,
x,
max_sum_len,
beam_width,
len_oovs,
x_padding_masks):
"""Using beam search to generate summary.
"""
# run body_sequence input through encoder
encoder_output, encoder_states = self.model.encoder(
replace_oovs(x, self.vocab))
coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)
# initialize decoder states with encoder forward states
decoder_states = self.model.reduce_state(encoder_states)
# initialize the hypothesis with a class Beam instance.
init_beam = Beam([self.vocab.SOS],
[0],
decoder_states,
coverage_vector)
# get the beam size and create a list for stroing current candidates
# and a list for completed hypothesis
k = beam_width
curr, completed = [init_beam], []
# use beam search for max_sum_len (maximum length) steps
for _ in range(max_sum_len):
# get k best hypothesis when adding a new token
topk = []
for beam in curr:
# When an EOS token is generated, add the hypo to the completed
# list and decrease beam size.
if beam.tokens[-1] == self.vocab.EOS:
completed.append(beam)
k -= 1
continue
for can in self.best_k(beam,
k,
encoder_output,
x_padding_masks,
x,
torch.max(len_oovs)
):
# Using topk as a heap to keep track of top k candidates.
# Using the sequence scores of the hypos to campare
# and object ids to break ties.
add2heap(topk, (can.seq_score(), id(can), can), k)
curr = [items[2] for items in topk]
# stop when there are enough completed hypothesis
if len(completed) == beam_width:
break
# When there are not engouh completed hypotheses,
# take whatever when have in current best k as the final candidates.
completed += curr
# sort the hypothesis by normalized probability and choose the best one
result = sorted(completed,
key=lambda x: x.seq_score(),
reverse=True)[0].tokens
return result