一、总结
【理解】seq2seq中的beam search算法过程:https://zhuanlan.zhihu.com/p/28048246
【理解】如何通俗的理解beam search?:https://zhuanlan.zhihu.com/p/82829880
二、代码
【实现】
How to Implement a Beam Search Decoder for Natural Language Processing:
https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/
from math import log
from numpy import array
from numpy import argmax
# beam search
def beam_search_decoder(data, k):
    """Return the k lowest-cost index sequences found by beam search.

    Args:
        data: Iterable of rows, one per decoding step. Each row is a sequence
            of per-token values, presumably probabilities in [0, 1]; the
            position within the row is used as the token index.
        k: Beam width, i.e. how many hypotheses are kept after each step.

    Returns:
        A list of ``[sequence, score]`` pairs sorted by ascending score, where
        ``sequence`` is a list of token indices and ``score`` is the
        accumulated negative log of the chosen values (lower is better).
        If ``data`` is empty, the single empty hypothesis ``[[], 0.0]`` is
        returned.

    Raises:
        ValueError: If a row contains a negative value (raised by ``log``).
    """
    # Start from a single empty hypothesis with zero cost.
    sequences = [[[], 0.0]]
    # Walk over each step in the sequence.
    for row in data:
        all_candidates = []
        # Expand every surviving hypothesis with every possible token.
        for seq, score in sequences:
            for token, prob in enumerate(row):
                # NOTE: the same `row` is used for every prefix, so the step
                # scores are not conditioned on the previously chosen token;
                # this is a simplified simulation of a decoder.
                if prob == 0:
                    # log(0) is undefined and would raise; a zero-probability
                    # token is given infinite cost so it is ranked last.
                    cost = float("inf")
                else:
                    # For prob in (0, 1], log(prob) <= 0, so subtracting it
                    # adds a non-negative cost to the running score.
                    cost = score - log(prob)
                all_candidates.append([seq + [token], cost])
        # Order all candidates by score; sorted() is stable, so ties keep
        # their expansion order.
        ordered = sorted(all_candidates, key=lambda candidate: candidate[1])
        # Keep only the k best hypotheses for the next step.
        sequences = ordered[:k]
    return sequences
# Demo input: a sequence of 10 steps over a vocabulary of 5 tokens.
# The steps alternate between a rising and a falling distribution.
ascending = [0.1, 0.2, 0.3, 0.4, 0.5]
descending = [0.5, 0.4, 0.3, 0.2, 0.1]
data = array([ascending, descending] * 5)

# Decode with a beam width of 3.
result = beam_search_decoder(data, 3)

# Print each surviving hypothesis as [token indices, score].
for hypothesis in result:
    print(hypothesis)
# Output of the run above (k = 3); each line is [token-index sequence, accumulated negative log score]
"""
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 6.931471805599453]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 7.154615356913663]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 7.154615356913663]
"""