束搜索
在seq2seq中我们使用贪心算法计算句子输出的概率,输出序列
但贪心可能并不是最好的,比如下面
那什么样最好呢,肯定是穷举,吧所有的顺序找出来,一定能找到最优解,,词表大小是n,时间步长T,那么时间复杂度是: O ( n T ) O(n^T) O(nT) 实际上是 n 1 + n 2 + . . . + n T n^1+n^2+...+n^T n1+n2+...+nT在第i个时间步需要在大小为 n i n^i ni的一层中遍历一遍寻找最大的,穷举法最后一层比其他层多几个幂级,所以前面省略了。
穷举法开销太大,所以折中一下提出了集束搜索。
- 保存最好的k个候选
- 每个时间步在kn个选项中选出最好的k个
注意,并不是在一个分支里再选出k个
,是上一层选出k个最好的分支,当前层总共有了kn个选项,在这kn个中选择k个最好的。 也就是说除了第一个时间步有n个,剩下每一个时间步都是kn个,时间复杂度是
O
(
k
n
T
)
O(knT)
O(knT),每层都是在kn个大小的序列中遍历找最好的,共T层
思考这两个问题看看是不是理解了?
- k=1时是贪心吗? ——是
- k=n时是穷举吗? ——不是
每个候选的最终得分:
1
L
α
l
o
g
p
(
y
1
,
.
.
.
,
y
L
)
=
1
L
α
∑
t
′
=
1
L
l
o
g
p
(
y
t
′
∣
y
1
,
.
.
.
,
y
t
′
)
\frac{1}{L^\alpha}log\ p(y_1,...,y_L) = \frac{1}{L^\alpha}\displaystyle\sum_{t^{'}=1}^Llog\ p(y_{t^{'}}|y_1,...,y_{t^{'}})
Lα1log p(y1,...,yL)=Lα1t′=1∑Llog p(yt′∣y1,...,yt′)
通常 α = 0.75 \alpha=0.75 α=0.75
因为越长的句子,经过累乘概率越小取log之后是个负数,得分比较低,所以前面乘 1 L α \frac{1}{L^\alpha} Lα1,相当于给长句子一点补偿。