In tasks such as speech recognition and machine translation, once a per-step probability matrix over words or characters has been produced, a decoder must turn it into the most likely symbol sequence, and beam search is a heuristic algorithm commonly used in the decoder.
The code below uses multiprocessing to run beam search over a batch in parallel (worker processes, not threads, despite the function name). multithread_bs(data, length, k, worker) takes data of shape [batch, max_length, char_num]; length, of shape [batch,], gives each sample's actual length; k is the beam width (beam_size); and worker is the number of worker processes.
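Before the full listing, here is a minimal self-contained sketch of what the decoder computes: it keeps the k best partial sequences, scored by summed negative log-probability (the 3-step, 3-symbol probability matrix is made up for illustration):

```python
from math import log

# hypothetical 3-step, 3-symbol probability matrix (each row sums to 1)
probs = [[0.1, 0.6, 0.3],
         [0.5, 0.2, 0.3],
         [0.1, 0.1, 0.8]]

k = 2  # beam width
beams = [([], 0.0)]  # (partial sequence, negative log-likelihood)
for row in probs:
    # expand every beam with every symbol, then keep the k best
    candidates = [(seq + [j], score - log(p))
                  for seq, score in beams
                  for j, p in enumerate(row)]
    beams = sorted(candidates, key=lambda c: c[1])[:k]

print(beams[0][0])  # best sequence: [1, 0, 2]
```

Because the score is additive over steps, sorting ascending by negative log-likelihood is equivalent to sorting descending by sequence probability.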
# -*- coding: utf-8 -*-
# @Author: huangneng
# @Date: 2019-05-30 08:48:31
# @Last Modified by: hn
# @Last Modified time: 2019-05-30 17:23:12
import numpy as np
from scipy.special import softmax
import time
from math import log
from multiprocessing import Pool
from itertools import repeat
def beam_search_decoder(data, length, k):
    # each candidate: [token sequence, cumulative negative log-probability]
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for t in range(length):
        row = data[t]
        all_candidates = list()
        # expand each current candidate with every symbol
        for seq, score in sequences:
            for j in range(len(row)):
                # sum negative log-probabilities rather than multiplying them,
                # so the score is the sequence's negative log-likelihood
                candidate = [seq + [j], score - log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score (lower is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select the k best
        sequences = ordered[:k]
    return sequences[0][0]
def multithread_bs(data, length, k, worker=4):
    # despite the name, this runs worker processes via multiprocessing.Pool;
    # the context manager ensures the pool is closed when decoding finishes
    with Pool(worker) as pool:
        decode_args = zip(list(data), list(length), repeat(k))
        result = pool.starmap(beam_search_decoder, decode_args)
    return result
if __name__ == '__main__':
    data = np.random.randn(5000, 512, 6)
    length = np.ones(5000, dtype=np.int64) * 512  # np.int is deprecated
    data = softmax(data, -1)
    start = time.time()
    # single process:
    # print([beam_search_decoder(d, l, 3) for d, l in zip(data, length)])
    # multiple processes:
    print(multithread_bs(data, length, 3))
    end = time.time()
    print(end - start)