小黑晚上要和frank meeting啦:Label Smoothing

本文探讨了如何通过标签平滑技术改进交叉熵损失函数,避免模型过拟合并提升损失函数的平滑性。作者解释了标签平滑前后概率分布的变化,并提供了核心代码实例,展示了如何在PyTorch中实现LabelSmoothingCrossEntropy模块。
摘要由CSDN通过智能技术生成

1.原理

在之前的交叉熵损失函数中,我们让模型学习到的标签类别的概率为1,其他的类别让模型学习出的概率为0,这样很容易让模型过拟合,并且使得损失函数不平滑,为了解决这个问题,我们使用标签平滑的方法。

平滑前的理想概率分布:

在这里插入图片描述

平滑后的理想概率分布:

在这里插入图片描述
经过标签平滑的方法,然后送入到交叉熵损失函数下,具体公式如下图(小黑找了半天才找到的图):
在这里插入图片描述
损失函数核心代码:

loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds,target,reduction = self.reduction,ignore_index = self.ignore_index)

核心代码对应公式:
在这里插入图片描述
小黑突然上面两张图与本文的第二张图在概率分布上的表示出现了不一致的现象,好多博客上的写法也不太一样,为此小黑纠结了一下午,不纠结了,明白啥意思就行了,学!跑!练!干!小黑冲!!!!!!
在这里插入图片描述

2.整体代码demo:

import torch.nn as nn
import torch.nn.functional as F
import torch
class LabelSmoothingCrossEntropy(nn.Module):
    
    def __init__(self,eps = 0.1,reduction = 'mean',ignore_index = -100):
        super(LabelSmoothingCrossEntropy,self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index
    
    def forward(self,output,target):
        # output:[num,num_tags]
        # target:[num]
        
        c = output.size()[-1]
        # log_preds:[num,num_tags]
        log_preds = F.log_softmax(output,dim = -1)
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum()
            if self.reduction == 'mean':
                loss = loss.mean()
        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds,target,reduction = self.reduction,ignore_index = self.ignore_index)

reduction = 'sum'
loss = LabelSmoothingCrossEntropy(reduction = reduction)
pred = torch.randn([4,10])
target = torch.ones([4]).long()
print('smooth loss:',loss(pred,target))
输出:

smooth loss: tensor(10.6777)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值