Decoding strategies

This post takes a close look at text generation techniques in NLP, including greedy decoding, beam search, random sampling, and refinements such as top-k and top-p sampling. It walks through real code to show how these techniques are used in practice, discusses their problems, such as the coherence of the generated text and sensible word choice, and also covers the repetition penalty and repeated n-gram removal strategies used to improve generation quality.

Greedy

Core idea: at each step, take the single most probable token as the output.
Method:
Take the model's probability distribution over the vocabulary for the next token, pick the argmax as the index of the token to emit, and then continue generating the following token from there.
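For concreteness, a minimal greedy decoding loop might look like the sketch below (assuming a hypothetical model object that returns next-token logits of shape (batch, seq_len, vocab_size); the names here are placeholders, not a specific API):

import torch

def greedy_decode(model, input_ids, max_new_tokens=50, eos_token_id=None):
    # input_ids: (1, seq_len) tensor holding the prompt / tokens generated so far
    for _ in range(max_new_tokens):
        logits = model(input_ids)              # assumed shape: (1, seq_len, vocab_size)
        next_token_logits = logits[:, -1, :]   # distribution over the next token
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # greedy: take the argmax
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break
    return input_ids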

Beam search

Core idea: beam search optimizes the search space on top of breadth-first search (similar to pruning) in order to reduce memory consumption.
Method:
At every decoding step we keep the top K candidate tokens. At the next step we decode one step further for each of these K candidates, take the top K continuations of each, and from the resulting K^2 candidate sequences keep only the top K. This repeats until decoding ends. Beam search is still essentially a greedy decoding method, so there is no guarantee that it finds the best overall decoding result.
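To make this concrete, here is a minimal sketch of a single beam-search expansion step (illustrative only; a real implementation, such as the HuggingFace one shown later, also handles finished hypotheses, length penalties, EOS, and so on):

import torch
import torch.nn.functional as F

def beam_step(next_token_logits, beam_scores, num_beams):
    # next_token_logits: (num_beams, vocab_size) logits for each current beam
    # beam_scores:       (num_beams,) accumulated log-probability of each beam
    vocab_size = next_token_logits.size(-1)
    log_probs = F.log_softmax(next_token_logits, dim=-1)
    scores = log_probs + beam_scores[:, None]             # score of every (beam, token) continuation
    scores = scores.view(-1)                              # flatten to num_beams * vocab_size
    top_scores, top_ids = torch.topk(scores, num_beams)   # keep only the best K continuations
    beam_ids = top_ids // vocab_size                      # which beam each continuation extends
    token_ids = top_ids % vocab_size                      # which token to append to that beam
    return top_scores, beam_ids, token_ids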

Random sampling

We can introduce some randomness into generation. If the language model says the next-token distribution over the vocabulary is p = (p_1, p_2, …, p_|V|), we can sample from this distribution to decide the next token. Compared with greedy decoding, sampling makes the generated text less mechanical because it is no longer fully deterministic.
A temperature T can also be applied to the logits before the softmax: the larger T is, the flatter the resulting distribution and the more random the sampling; the smaller T is, the sharper the distribution and the less random the sampling.
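A minimal sketch of temperature-scaled sampling (logits is assumed to be a 1-D tensor of next-token logits):

import torch
import torch.nn.functional as F

def sample_with_temperature(logits, temperature=1.0):
    # Divide the logits by T before the softmax: T > 1 flattens the distribution,
    # T < 1 sharpens it (as T approaches 0 this degenerates to greedy argmax).
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)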

Problems

① The generated text is often incoherent and can contradict the preceding context.
② It can produce odd sentences containing rare words.

top-k sampling

Take the K most probable tokens, renormalize their probabilities, and sample from those. The catch is that K is hard to choose, because the shape of the distribution varies a lot from step to step: sometimes it is flat, sometimes it is concentrated. When the distribution is flat, a small K can throw away good candidates; when it is peaked, a large K lets in strange tokens, at which point it is barely different from plain random sampling.
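A compact sketch of the idea (the fuller, more careful HuggingFace implementation is shown in the code walkthrough below):

import torch
import torch.nn.functional as F

def top_k_sample(logits, k=50):
    top_values, top_indices = torch.topk(logits, k)    # keep the k largest logits
    probs = F.softmax(top_values, dim=-1)              # renormalize over the kept tokens
    choice = torch.multinomial(probs, num_samples=1)   # sample among them
    return top_indices.gather(-1, choice)              # map back to vocabulary ids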

top-p (nucleus) sampling

The Curious Case of Neural Text Degeneration
https://arxiv.org/abs/1904.09751

Advantage: no need to pick K by hand; the paper's authors use p = 0.95.
Sort the token probabilities from largest to smallest and accumulate them; once the cumulative sum exceeds the threshold p, all remaining lower-probability tokens are discarded, and sampling is done over the kept tokens. With p = 0.95, a flat distribution might keep roughly the top 4 tokens, while a very peaked one might keep only the top 2.
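A tiny numeric illustration of the cumulative-sum cutoff (the probabilities are made up, just to show the mechanics):

import torch

probs = torch.tensor([0.5, 0.3, 0.1, 0.06, 0.04])  # already sorted in descending order
cumulative = torch.cumsum(probs, dim=-1)           # tensor([0.50, 0.80, 0.90, 0.96, 1.00])
keep = cumulative <= 0.95                          # tensor([ True,  True,  True, False, False])
# In practice the first token that crosses the threshold is kept as well (see the "shift"
# trick in the HuggingFace implementation below), so here 4 tokens would survive.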

Reference: https://zhuanlan.zhihu.com/p/115076102

Code walkthrough

In fact, all of the sampling methods above are already implemented in HuggingFace's library (thankfully!). Let's take a look at the code.

First, top-k and top-p sampling:

# The input is raw logits, and the edge cases are handled carefully (though I feel the case where
# both k and p are given at the same time is missed; combining them like that is arguably not appropriate)
# torch.cumsum is used in a clever way
# It avoids the awkward situation where not a single token would remain selectable
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
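A minimal usage sketch with made-up logits (filtered positions are set to -inf, so they receive zero probability after the softmax); note that the function modifies logits in place, hence the clone():

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])        # (batch=1, vocab=5), made-up values
filtered = top_k_top_p_filtering(logits.clone(), top_k=3)  # keep only the 3 highest logits
probs = F.softmax(filtered, dim=-1)                        # filtered positions become probability 0
next_token = torch.multinomial(probs, num_samples=1)       # sample the next token id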

Next, the repetition penalty:

# The input is again logits (lprobs)
# along with the previously generated tokens and a penalty coefficient (greater than 1)
# Note that positive and negative scores have to be handled differently
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then the repetition penalty has to be multiplied in to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty
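Why the sign check matters: these scores are logits / log-probabilities and can be negative. Dividing a positive score by a penalty > 1 lowers it, but dividing a negative score would raise it, so negative scores are multiplied instead. A tiny illustration with a hypothetical penalty of 1.2:

penalty = 1.2
positive_score, negative_score = 3.0, -3.0
print(positive_score / penalty)  # 2.5  -> the repeated token becomes less likely, as intended
print(negative_score * penalty)  # -3.6 -> also less likely, as intended
# Dividing the negative score instead (-3.0 / 1.2 = -2.5) would wrongly make the token MORE likely.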

Finally, repeated n-gram removal:

# This function returns, for each hypothesis, a list of tokens that must not be generated next
# The way the n-grams are built is a neat trick worth borrowing
# Below is a 3-gram example:
# a = [1,2,3,4,5]
# for ngram in zip(*[a[i:] for i in range(3)]):
#    print(ngram)
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].numpy().tolist()
        generated_ngram = generated_ngrams[idx]
        # this is the clever one-liner
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
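A quick check of what the function returns (hypothetical token ids; with no_repeat_ngram_size=2, a hypothesis that has generated 5 8 5 has already followed the prefix 5 with 8, so 8 is banned as the next token):

prev_input_ids = torch.tensor([[5, 8, 5]])  # one hypothesis, tokens generated so far
banned = calc_banned_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=3)
print(banned)  # [[8]] -- generating 8 again would repeat the 2-gram (5, 8)

Finally, here is how the top-k/top-p filtering is plugged into the beam-search generation loop (the do_sample branch) in the HuggingFace code: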
if do_sample:
    # this is the sampling approach covered today
    _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
    # Top-p/top-k filtering; this step rebuilds the candidate set
    _scores = top_k_top_p_filtering(
        _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )  # (batch_size * num_beams, vocab_size)
    # re-organize to group the beam together to sample from all beam_idxs
    _scores = _scores.contiguous().view(
        batch_size, num_beams * vocab_size
    )  # (batch_size, num_beams * vocab_size)

    # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
    probs = F.softmax(_scores, dim=-1)
    # sample
    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
    # Compute next scores
    next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
    # sort the sampled vector to make sure that the first num_beams samples are the best
    next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
    next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
else:
    # this is yesterday's plain beam search approach
    # simply add the log-probabilities to accumulate the conditional probability of the sequence
    next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

    # re-organize to group the beam together (we are keeping top hypothesis accross beams)
    next_scores = next_scores.view(
        batch_size, num_beams * vocab_size
    )  # (batch_size, num_beams * vocab_size)

    next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
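In both branches the scores are flattened to (batch_size, num_beams * vocab_size), so each selected index encodes both a beam and a token. Later in the generation loop they are recovered roughly like this (a sketch of the idea, not the verbatim library code):

beam_id = next_tokens // vocab_size  # which beam the candidate continues
token_id = next_tokens % vocab_size  # which token id gets appended to that beam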