Beam Search源码理解

class Beam(object):
    def __init__(self, size,sos,eos):
        self.size = size
        self.tt = torch.cuda
 
        self.scores = self.tt.FloatTensor(size).zero_()
        # 大小为[beam_size],记录当前每个beam的分数总和
        
        self.prevKs = []
        # 记录每一步选取的是第几个beam,便于最后回溯生成结果
        
        self.nextYs = [self.tt.LongTensor(size)
                       .fill_(0)]
        # nextYs: [seq_len=1, beam_size],随着预测过程seq_len逐渐增加,表示每一步的输出结果
        # seq_len即为time_step
        
        self.nextYs[0][0] = sos
        
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []
 
    def getCurrentState(self):
        batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
        # batch: [beam_size, seq_len],用于加入到下一次模型的输入中。
        return batch
 
    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]
 
    def advance(self, wordLk):
        '''
        更新beam中的信息
        wordLk: [beam_size, vocab_size],上一个时间节点每个beam的模型预测结果,需要用LogSoftMax进行归一化
        '''
 
        numWords = wordLk.size(1)
        # numWords: vocab_size
        
 
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # scores: [beam_size]
            # wordLk是当前的分数,scores是之前的分数,加起来得到beamLk: [beam_size, vocab_size]
            
 
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
                    # 把第i个beam的概率全部设置为负无穷
        else:
            beamLk = wordLk[0]
            # beamLk: [vocab_size] 刚开始只有第一个beam
        
        flatBeamLk = beamLk.view(-1) # beamlLk展开
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) # topk个最好分数
        
        self.scores = bestScores
        # scores: [beam_size]
 
        prevK = bestScoresId // numWords
        # prevK: [beam_size]
        self.prevKs.append(prevK)
        # prevKs: [time_step, beam_size] 记录了每个时间节点的结果来自于第几个beam
        self.nextYs.append((bestScoresId - prevK * numWords))
        # nextYs: [seq_len, beam_size] 记录了每个事件节点选取的id, seq_len即time_step
        
        # 对nextYs的最后一个时间节点进行遍历,检查是否出现了结束符
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))
                # i 表示第几个beam
                # 若出现结束符,将(总分数,句子长度,beam的id)三元组加入到finished列表中
                # finished列表中存的是已经结束的beam的信息
 
        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            # 当nextYs中最后一个时间点的第一个id为结束符时,将eosTop设置为True
            self.eosTop = True
 
    def done(self):
        # 当eosTop为True且已经结束的beam数大于等于beam_size的时候就结束。
        return self.eosTop and len(self.finished) >=self.size
 
    def getFinal(self):
        
        if len(self.finished) == 0:
            # 这里的情况就是所有beam的句子长度都达到了max_length但没有任何一个产生了结束符
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
            # 这种情况下就手动将第0个beam设置为已经结束
        self.finished.sort(key=lambda a: -a[0])
        # 将finished按beam的分数由大到小排序
        if len(self.finished) != self.size:
            # 将没有结束的句子也按(分数,长度,beam_id)三元组的形势加入到finished中
            unfinished=[]
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i)) 
            unfinished.sort(key=lambda a: -a[0])
            self.finished+=unfinished[:self.size-len(self.finished)]
        # 已经结束的beam排在未结束的句子前面
        return self.finished[:self.size]
 
    def getHyp(self, beam_res):
        """
        回溯,生成结果
        """
        
        # beam_res 传入的就是finished列表,由get_final得到
        hyps=[]
        for _,timestep, k in beam_res:
            # k是指该结果来自于第几个beam
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                # prevKs: [time_step, beam_size] 记录了每个时间节点的结果来自于第几个beam
                hyp.append(self.nextYs[j+1][k])
                # nextYs: [time_step, beam_size] 记录了每个beam的每一步选择,将该id加入到hyp中
                k = self.prevKs[j][k]
                # k为结果来自于第几个beam
            hyps.append(hyp[::-1])# hyp反过来加入到hyps中
        # 最后得到的hyps:[beam_size, ~]列表,~即长度不一,是每一个beam的预测结果,按分数大小排列
        return hyps
    
    def buildTargetTokens(self, preds):
        # preds即为getHyp产生的hyps,记录了每个beam产生的结果,按分数大小排列
        # 这个函数的目的是截断eos之后的结果
        sentence=[]
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok==self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值