beamsearch以及python实现----对比贪婪greedsearch和维特比算法

1.首先说beamsearch是什么

举个例子很容易说清楚:

seq2seq模型的decoder解码的时候:

1: 生成第1个词的时候,选择概率最大的2个词,假设为a,c,那么当前序列就是a,c
2:生成第2个词的时候,我们将当前序列a和c,分别与词表中的所有词进行组合,得到新的6个序列aa ab ac ca cb cc,然后从其中选择2个得分最高的,作为当前序列,假如为aa cb
3:后面会不断重复这个过程,直到遇到结束符为止。最终输出2个得分最高的序列。

简单来说,beamsearch是每次选取概率最大的beam width个词组作为结果,并将它们分别传入下一个时刻的decode阶段进行解码得到新的组合序列,在从新的序列中选取最大的beam width个词组,一直循环到结束。

2.对比greedsearch和维特比算法

greedsearch是beamsearch在beam width=1的情况下的特例。相对于greedsearch,beamsearch实际上是增加了搜索空间,但也只能做到局部最优解,不一定是全局最优解。因为考虑到seq2seq的inference阶段的搜索空间过大而导致的搜索效率降低,所以即使是一个相对的局部优解在工程上也是可接受的。

那么怎么能得到一个全局最优解呢?个人理解只能是beam width等于语料库的大小时才能找到全局最优,当然这只是理论上,不存在应用上的可能,毕竟语料库太大,也没有必要。还有一种算法是可以reach全局最优解的,维特比算法,专为全局最优而生。但他是建立在前后隐状态独立的条件下的,所以decoder不能用维特比算法的根本原因在于不独立,而不是单单是因为隐状态太大。

面试的时候,我发现更多的面试官是从维特比算法这个角度出发来理解beamsearch的,也就是说,beamsearch可以理解为维特比算法的删减版,砍掉了所有解中的一部分来取得一个局部最优的结果。而维特比算法实际上是动态规划的。

3.python实现
from math import log
from numpy import array
from numpy import argmax

# beam search
def beam_search_decoder(data, k):
    # 始终维护一个长度是k的sequence
    sequences = [[list(), 1.0]]  # [[[], 1.0]] 初始长度是1  内部列表 第一个元素是所选的index来列表,第二个元素是概率的乘积
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select k best
        sequences = ordered[:k]  # 只取前k个
    return sequences

def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]

# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)

每次循环后sequence的值:

[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]
[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]]
[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]]
[[[4, 0, 4, 0], 0.23083509858308343], [[4, 0, 3, 0], 0.3051474021030719], [[4, 1, 4, 0], 0.3051474021030719]]
[[[4, 0, 4, 0, 4], 0.1600026977571413], [[4, 0, 3, 0, 4], 0.21151206142293622], [[4, 1, 4, 0, 4], 0.21151206142293622]]
[[[4, 0, 4, 0, 4, 0], 0.11090541883234757], [[4, 0, 4, 0, 4, 1], 0.1466089890297302], [[4, 0, 3, 0, 4, 0], 0.1466089890297302]]
[[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158], [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145], [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]]
[[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184], [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441], [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]]
[[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901], [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468], [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]]
[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108], [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397], [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]

看一下最终结果:

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值