PPL代码

根据PPL两种不同的计算公式,有两种不同的代码实现。

实现一:使用perplexity的对数形式:将每个位置上的概率取对数再平均

# 定义计算PPL的函数
def calculate_ppl(model, conversations):
    total_loss = 0
    
    for conversation in conversations:
        
        with torch.no_grad():
            # 计算对话的概率分布
            outputs = model(input_ids=input_ids, labels=target_ids)
            logits = outputs.logits
            loss = CrossEntropyLoss(reduction='sum')(logits.view(-1, logits.shape[-1]), target_ids.view(-1))
        
        total_loss += loss.item()
  
    avg_loss = total_loss / len(conversations)
    ppl = torch.exp(avg_loss)

实现二:

针对文本中的词预测任务来说,离散概率分布p的困惑度由下式给出,其中H(p) 是该分布的熵,x遍历事件空间。概率分布的perplexity:

 

代码:

#nn.NLLLoss负对数似然函数作为损失函数
self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
loss = self.criterion_ppl(
            logit.contiguous().view(-1, logit.size(-1)),
            dec_batch.contiguous().view(-1),
        )
ppl =  math.exp(min(loss.item(), 100))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值