CrossEntropyLoss改进


前言

CrossEntropyLoss 是分类任务中经常使用的损失函数,但是在某些情况下,其优化效果并不是很好,本文介绍了最近出现的对CrossEntropyLoss进行改进的新损失函数

一、CrossEntropyLoss

公式:
在这里插入图片描述
上图是pytorch版实现的CrossEntropyLoss,可以看出其主要作用是优化了正例对应的logits(logits介绍见上一篇博文)并使其无限大与其他类别的logits,这种过强的要求可能使得模型难以训练至收敛,因而出现了LabelSM版本的CrossEntropy,以及Sparse Softmax

顺带提一句,pytorch版本的CrossEntropyLoss是对dim=1进行的计算,
因而我们需要把各个类别的logits放到dim=1上来

二、SmoothCrossEntropy

公式:
SmoothCrossEntropy对应的公式为:
在这里插入图片描述
优势:
当 label smoothing 的 loss 函数为 cross entropy 时,如果 loss 取得极值点,则正确类和错误类的 logit 会保持一个常数距离,且正确类和所有错误类的 logits 相差的常数是一样的,都是 log ⁡ ( K − ( K − 1 ) α α ) \log(\frac{K-(K-1)\alpha}{\alpha}) log(αK(K1)α)
证明见:知乎

code:

class SmoothCrossEntropy(nn.Module):
    """
    loss = SmoothCrossEntropy()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(5)
    output = loss(input, target)
    """
    def __init__(self, alpha=0.1):
        super(SmoothCrossEntropy, self).__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        num_classes = logits.shape[-1]
        alpha_div_k = self.alpha / num_classes
        target_probs = F.one_hot(labels, num_classes=num_classes).float() * \
            (1. - self.alpha) + alpha_div_k
        loss = -(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1)
        return loss.mean()

代码如下(示例):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import  ssl
ssl._create_default_https_context = ssl._create_unverified_context

三、Sparse Softmax

公式:
在这里插入图片描述
优势:
这是苏神在CAIL2020中提出的一个类别数过多的预测问题损失函数,我们只需要优化前topK项,使得 s t s_t st 大于topk即可,不必要大于最小的 log ⁡ ( n − 1 ) \log(n-1) log(n1)
,只需大于topk中最小的 l o g ( k ) log(k) log(k)即可,可以防止过度训练
证明
pytoch版本:

def Sparse_Softmax(predictions, token_type_id, input_ids, vocab_size):

    predictions = predictions[:, :-1].contiguous()
    target_mask = token_type_id[:, 1:].contiguous()
    """
       target_mask : 句子a部分和pad部分全为0, 而句子b部分为1
    """
    predictions = predictions.view(-1, vocab_size)
    labels = input_ids[:, 1:].contiguous()
    labels = labels.view(-1)
    target_mask = target_mask.view(-1).float()
    # 正loss
    pos_loss = predictions[list(range(predictions.shape[0])), labels]
    # 负loss
    y_pred = torch.topk(predictions, k=args.k_sparse)[0]
    neg_loss = torch.logsumexp(y_pred, dim=-1)

    loss = neg_loss - pos_loss
    return (loss * target_mask).sum() / target_mask.sum()  ## 通过mask 取消 pad 和句子a部分预测的影响

L-Softmax、SM-Softmax、AM-Softmax待补充

  • 3
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值