class Beam(object):
    """Beam-search state for a single example.

    Tracks, per decoding step: the cumulative log-probability of each live
    beam (`scores`), a back-pointer telling which beam each entry came from
    (`prevKs`), and the token chosen by each beam (`nextYs`). Complete
    hypotheses are reconstructed by backtracking in `getHyp`.

    Args:
        size: beam width (number of hypotheses kept per step).
        sos:  start-of-sequence token id, placed on beam 0 at step 0.
        eos:  end-of-sequence token id that terminates a hypothesis.
    """
    def __init__(self, size, sos, eos):
        self.size = size
        # Fall back to CPU tensor types when CUDA is unavailable; the
        # original hard-coded torch.cuda and crashed on CPU-only machines.
        self.tt = torch.cuda if torch.cuda.is_available() else torch
        # scores: [beam_size] — running log-prob total of each beam.
        self.scores = self.tt.FloatTensor(size).zero_()
        # prevKs: one [beam_size] tensor per step — back-pointers used to
        # backtrack the chosen path in getHyp.
        self.prevKs = []
        # nextYs: one [beam_size] tensor per step — the token id each beam
        # emitted at that step (len(nextYs) - 1 == current time step).
        self.nextYs = [self.tt.LongTensor(size).fill_(0)]
        self.nextYs[0][0] = sos
        # EOS bookkeeping.
        self._eos = eos
        self.eosTop = False   # True once beam 0's latest token is EOS
        self.finished = []    # (score, length, beam_id) of ended hypotheses

    def getCurrentState(self):
        """Return the last emitted tokens as a [beam_size, 1] tensor,
        suitable for feeding back into the decoder."""
        return self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)

    def getCurrentOrigin(self):
        """Get the backpointers for the current timestep."""
        return self.prevKs[-1]

    def advance(self, wordLk):
        """Advance the beam one step.

        Args:
            wordLk: [beam_size, vocab_size] per-beam token scores for this
                step; assumed to be LogSoftmax-normalised log-probabilities
                (TODO confirm against the caller).
        """
        numWords = wordLk.size(1)  # vocab_size

        if len(self.prevKs) > 0:
            # Total hypothesis score = running beam score + this step's score.
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # A beam that already emitted EOS must not be extended: sink its
            # whole row so topk can never select it again.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            # First step: every beam starts from the same SOS state, so only
            # row 0 is meaningful (otherwise topk would pick duplicates).
            beamLk = wordLk[0]

        # Pick the `size` best (beam, token) pairs over the flattened scores.
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.scores = bestScores
        # Decompose each flat index into (source beam, token id).
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)

        # Record every hypothesis that just ended with EOS as
        # (total score, sentence length, beam index).
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """True when the top beam has ended and at least `size` hypotheses
        have finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return the best `size` hypotheses as (score, length, beam_id)
        triples — finished hypotheses first, each group sorted by
        descending score."""
        if len(self.finished) == 0:
            # No beam ever emitted EOS (e.g. max_length was reached):
            # manually treat the current best beam as finished.
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            # Pad the result with the best still-unfinished beams.
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """Backtrack through prevKs/nextYs to materialise each hypothesis.

        Args:
            beam_res: iterable of (score, timestep, beam_id) triples, as
                produced by getFinal().
        Returns:
            List of token lists (variable length), one per triple, in the
            same order as beam_res.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            # Walk backwards from the final step, following beam k's
            # back-pointer chain; nextYs[j+1] pairs with prevKs[j].
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])  # reverse into left-to-right order
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each predicted sequence at the first EOS token.

        Args:
            preds: list of token sequences, as produced by getHyp().
        Returns:
            List of token lists with everything from the first EOS on
            removed.
        """
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
# Beam Search源码理解 (Beam search: understanding the source code)
# 最新推荐文章于 2024-04-08 20:52:01 发布 (blog-scrape metadata: latest recommended article published 2024-04-08 20:52:01)