Beam search 算法在文本生成中用得比较多,用于选择较优的结果(可能并不是最优的)。接下来将以seq2seq机器翻译为例来说明这个Beam search的算法思想。
在机器翻译中,beam search算法在测试的时候用的,因为在训练过程中,每一个decoder的输出是有与之对应的正确答案做参照,也就不需要beam search去加大输出的准确率。
有如下从中文到英语的翻译:
中文:
我 爱 学习,学习 使 我 快乐
英语:
I love learning, learning makes me happy
在这个测试中,中文的词汇表是{我,爱,学习,使,快乐},长度为5。英语的词汇表是{I, love, learning, make, me, happy}(全部转化为小写),长度为6。那么首先使用seq2seq中的编码器对中文序列(记这个中文序列为
X
X
X)进行编码,得到语义向量
C
C
C。
得到语义向量
C
C
C后,进入解码阶段,依次翻译成目标语言。在正式解码之前,有一个参数需要设置,那就是beam search中的beam size,这个参数就相当于top-k中的k,选择前k个最有可能的结果。在本例中,我们选择beam size=3。
来看解码器的第一个输出 y 1 y_1 y1,在给定语义向量 C C C的情况下,首先选择英语词汇表中最有可能k个单词,也就是依次选择条件概率 P ( y 1 ∣ C ) P(y_1|C) P(y1∣C)前3大对应的单词,比如这里概率最大的前三个单词依次是 I I I, l e a r n i n g learning learning, h a p p y happy happy。
接着生成第二个输出 y 2 y_2 y2,在这个时候我们得到了那些东西呢,首先我们得到了编码阶段的语义向量 C C C,还有第一个输出 y 1 y_1 y1。此时有个问题, y 1 y_1 y1有三个,怎么作为这一时刻的输入呢(解码阶段需要将前一时刻的输出作为当前时刻的输入),答案就是都试下,具体做法是:
- 确定 I I I为第一时刻的输出,将其作为第二时刻的输入,得到在已知 ( C , I ) (C, I) (C,I)的条件下,各个单词作为该时刻输出的条件概率 P ( y 2 ∣ C , I ) P(y_2|C,I) P(y2∣C,I),有6个组合,每个组合的概率为 P ( I ∣ C ) P ( y 2 ∣ C , I ) P(I|C)P(y_2|C, I) P(I∣C)P(y2∣C,I)。
- 确定 l e a r n i n g learning learning为第一时刻的输出,将其作为第二时刻的输入,得到该条件下,词汇表中各个单词作为该时刻输出的条件概率 P ( y 2 ∣ C , l e a r n i n g ) P(y_2|C, learning) P(y2∣C,learning),这里同样有6种组合;
- 确定 h a p p y happy happy为第一时刻的输出,将其作为第二时刻的输入,得到该条件下各个单词作为输出的条件概率 P ( y 2 ∣ C , h a p p y ) P(y_2|C, happy) P(y2∣C,happy),得到6种组合,概率的计算方式和前面一样。
这样就得到了18个组合,每一种组合对应一个概率值
P
(
y
1
∣
C
)
P
(
y
2
∣
C
,
y
1
)
P(y_1|C)P(y_2|C, y_1)
P(y1∣C)P(y2∣C,y1),接着在这18个组合中选择概率值top3的那三种组合,假设得到
I
l
o
v
e
I love
Ilove,
I
h
a
p
p
y
I happy
Ihappy,
l
e
a
r
n
i
n
g
m
a
k
e
learning make
learningmake。
接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:
P
(
Y
∣
C
)
=
P
(
y
1
∣
C
)
P
(
y
2
∣
C
,
y
1
)
,
.
.
.
,
P
(
y
6
∣
C
,
y
1
,
y
2
,
y
3
,
y
4
,
y
5
)
P(Y|C)=P(y_1|C)P(y_2|C,y_1),...,P(y_6|C,y_1,y_2,y_3,y_4,y_5)
P(Y∣C)=P(y1∣C)P(y2∣C,y1),...,P(y6∣C,y1,y2,y3,y4,y5)
以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。
Beam search算法也有许多改进的地方,根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数(记生成序列的单词数为
N
N
N),这样比较的就是每个单词的平均概率大小。
此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。
参考文献:
吴恩达-《序列模型》课程