transformers.generator_utils函数源码解析之RepetitionPenaltyLogitsProcessor

主要记录源码中解决文本生成中词组重复出现的问题,代码中有具体操作解析。

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        #scores为cur-step的词表分布[batch,seq,vocab_size],input_ids为输入decoder的文本序列[batch,seq],则score则是获取当前已经生成文本序列的token概率
        score = torch.gather(scores, 1, input_ids) 

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        #减少已经出现的token的概率
        score = torch.where(score < 0, score * self.penalty, score / self.penalty) 
        
        #将减少后的概率重分配到原始的cur-step词表分布中
        scores.scatter_(1, input_ids, score) 
        return scores

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值