BeamSearch算法原理及代码解析

1.算法原理

beam search有一个超参数beam_size,设为 k 。第一个时间步长,选取当前条件概率最大的 k 个词,当做候选输出序列的第一个词。之后的每个时间步长,基于上个步长的输出序列,挑选出所有组合中条件概率最大的 k 个,作为该时间步长下的候选输出序列。始终保持 k 个候选。最后从k 个候选中挑出最优的。

2.中心思想

假设有n句话,每句话的长度为T。encoder的输出shape为(n, T, hidden_dim),扩展成(n*beam_size, T, hidden_dim)。decoder第一次输入shape为(n, 1),扩展到(n*beam_size, 1)。经过一次解码,输出得分的shape为(n*beam_size, vocab_size),路径得分log_prob的shape为(n*beam_size, 1),两者相加得到当前帧的路径得分。reshape到(n, beam_size*vocab_size),取topk(beam_size),得到排序后的索引(n, beam_size),索引除以vocab_size,得到的是每句话的beam_id,用来获取当前路径前一个字;对vocab_size取余,得到的是每句话的token_id,用来获取当前路径下一次字。

3.代码解析

def beam_search():
    k_prev_words = torch.full((k, 1), SOS_TOKEN, dtype=torch.long) # (k, 1)
    # 此时输出序列中只有sos token
    seqs = k_prev_words #当前路径(k, 1)
    # 初始化scores向量为0
    top_k_scores = torch.zeros(k, 1)
    complete_seqs = [] #已完成序列
    complete_seqs_scores = [] #已完成序列的得分
    step = 1
    hidden = torch.zeros(1, k, hidden_size) # encoder的输出: (1, k, hidden_size)
    while True:
        outputs, hidden = decoder(k_prev_words, hidden) # outputs: (k, seq_len, vocab_size)
        next_token_logits = outputs[:,-1,:] # (k, vocab_size)
        if step == 1:
        # 因为最开始解码的时候只有一个结点<sos>,所以只需要取其中一个结点计算topk
          top_k_scores, top_k_words = next_token_logits[0].topk(k, dim=0, largest=True, sorted=True)
        else:
        # 此时要先展开再计算topk,如上图所示。
        # top_k_scores: (k) top_k_words: (k)
          top_k_scores, top_k_words = next_token_logits.view(-1).topk(k, 0, True, True)
        prev_word_inds = top_k_words / vocab_size  # (k)  实际是beam_id,哪个beam就是哪条最优路径
        next_word_inds = top_k_words % vocab_size  # (k)  实际是token_id,在单词表中的index
        # seqs: (k, step) ==> (k, step+1)
        seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)
  
        # 当前输出的单词不是eos的有哪些(输出其在next_wod_inds中的位置, 实际是beam_id)
        incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                                next_word != vocab['<eos>']]
        # 输出已经遇到eos的句子的beam id(即seqs中的句子索引)
        complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

        if len(complete_inds) > 0:
            complete_seqs.extend(seqs[complete_inds].tolist()) # 加入句子
            complete_seqs_scores.extend(top_k_scores[complete_inds]) # 加入句子对应的累加log_prob
        # 减掉已经完成的句子的数量,更新k, 下次就不用执行那么多topk了,因为若干句子已经被解码出来了
        k -= len(complete_inds) 
  
        if k == 0: # 完成
           break
  
        # 更新下一次迭代数据, 仅专注于那些还没完成的句子 
        seqs = seqs[incomplete_inds]
        hidden = hidden[prev_word_inds[incomplete_inds]]
        top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)   #(s, 1) s < k
        k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) #(s, 1) s < k
  
        if step > max_length: # decode太长后,直接break掉
            break
        step += 1
    i = complete_seqs_scores.index(max(complete_seqs_scores)) # 寻找score最大的序列
    # 有些许问题,在训练初期一直碰不到eos时,此时complete_seqs为空
    seq = complete_seqs[i] 
 
    return seq

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值