1.首先说beamsearch是什么
举个例子很容易说清楚:
seq2seq模型的decoder解码的时候:
1: 生成第1个词的时候,选择概率最大的2个词,假设为a,c,那么当前序列就是a,c
2:生成第2个词的时候,我们将当前序列a和c,分别与词表中的所有词进行组合,得到新的6个序列aa ab ac ca cb cc,然后从其中选择2个得分最高的,作为当前序列,假如为aa cb
3:后面会不断重复这个过程,直到遇到结束符为止。最终输出2个得分最高的序列。
简单来说,beamsearch是每次选取概率最大的beam width个词组作为结果,并将它们分别传入下一个时刻的decode阶段进行解码得到新的组合序列,在从新的序列中选取最大的beam width个词组,一直循环到结束。
2.对比greedsearch和维特比算法
greedsearch是beamsearch在beam width=1的情况下的特例。相对于greedsearch,beamsearch实际上是增加了搜索空间,但也只能做到局部最优解,不一定是全局最优解。因为考虑到seq2seq的inference阶段的搜索空间过大而导致的搜索效率降低,所以即使是一个相对的局部优解在工程上也是可接受的。
那么怎么能得到一个全局最优解呢?个人理解只能是beam width等于语料库的大小时才能找到全局最优,当然这只是理论上,不存在应用上的可能,毕竟语料库太大,也没有必要。还有一种算法是可以reach全局最优解的,维特比算法,专为全局最优而生。但他是建立在前后隐状态独立的条件下的,所以decoder不能用维特比算法的根本原因在于不独立,而不是单单是因为隐状态太大。
面试的时候,我发现更多的面试官是从维特比算法这个角度出发来理解beamsearch的,也就是说,beamsearch可以理解为维特比算法的删减版,砍掉了所有解中的一部分来取得一个局部最优的结果。而维特比算法实际上是动态规划的。
3.python实现
from math import log
from numpy import array
from numpy import argmax
# beam search
def beam_search_decoder(data, k):
# 始终维护一个长度是k的sequence
sequences = [[list(), 1.0]] # [[[], 1.0]] 初始长度是1 内部列表 第一个元素是所选的index来列表,第二个元素是概率的乘积
# walk over each step in sequence
for row in data:
all_candidates = list()
# expand each current candidate
for i in range(len(sequences)):
seq, score = sequences[i]
for j in range(len(row)):
candidate = [seq + [j], score * -log(row[j])]
all_candidates.append(candidate)
# order all candidates by score
ordered = sorted(all_candidates, key=lambda tup: tup[1])
# select k best
sequences = ordered[:k] # 只取前k个
return sequences
def greedy_decoder(data):
# index for largest probability each row
return [argmax(s) for s in data]
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
print(seq)
每次循环后sequence的值:
[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]
[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]]
[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]]
[[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719]]
[[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622]]
[[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302]]
[[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]]
[[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]]
[[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]]
[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]
看一下最终结果:
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]