标签平滑(label smoothing)

目录

1.标签平滑主要解决什么问题?

2.标签平滑是怎么操作的?

3.标签平滑公式

4.代码实现


标签平滑(label smoothing)出自GoogleNet v3

关于one-hot编码的详细知识请见:One-hot编码

1.标签平滑主要解决什么问题?

传统的one-hot编码会带来的问题无法保证模型的泛化能力,使网络过于自信会导致过拟合。
全概率和0概率鼓励所属类别和其他类别之间的差距尽可能加大,而由梯度有界可知,这种情况很难adapt。会造成模型过于相信预测的类别。而标签平滑可以缓解这个问题。
 

2.标签平滑是怎么操作的?

标签平滑是把one-hot中概率为1的那一项进行衰减,避免过度自信,衰减的那部分的自信被平均分到每一个类别中

例如:

一个4分类任务,label = (0,1,0,0)

labeling smoothing = (\frac{0.001}{4},1-0.001+\frac{0.001}{4}\frac{0.001}{4}\frac{0.001}{4})=(0.00025,0.99925,0.00025,0.00025)

 其中,概率加起来等于1。

3.标签平滑公式

交叉熵(Cross Entropy):H(q,p)=-\sum_{k=1}^{k}log(p_k)q_k

其中,q为标签值,p为预测结果,k为类别。即q为one-hot编码结果。

labeling smothing:将q进行标签平滑变为q',让模型输出的p分布去逼近q'

q'(k|x)=(1-\varepsilon )\delta _{k,y} +\varepsilon u(k),其中u(k)为一个概率分布,这里采用均匀分布u(k)=\frac{1}{k}),则得到q'(k|x)=(1-\varepsilon )\delta _{k,y} +\frac{\varepsilon }{k}

        其中,\delta _{k,y}为原分布q, ϵ ∈(0,1)是一个超参数。

        由以上公式可以看出,这种方式使label有 ϵ  概率来自于均匀分布 1−ϵ 概率来自于原分布。这就相当于在原label上增加噪声,让模型的预测值不要过度集中于概率较高的类别,把一些概率放在概率较低的类别。

故标签平滑后的交叉熵损失函数为:H(q',p)=-\sum_{k=1}^{k}logp(k)q'(k)=(1-\varepsilon )H(q,p)+\varepsilon H(u,p)

那这个公式是怎么得来的呢?

将q'(k|x)带入交叉熵损失函数:

H(q',p)=-\sum_{k=1}^{k}log(p_k)q'_k

=-\sum_{k=1}^{k}log(p_k)[(1-\varepsilon )\delta _{k,y}+\frac{\varepsilon }{k}]

=-\sum_{k=1}^{k}log(p_k)(1-\varepsilon )\delta _{k,y}+[-\sum_{k=1}^{k}log(p_k)\frac{\varepsilon }{k}]

=(1-\varepsilon )*[-\sum_{k=1}^{k}log(p_k)\delta _{k,y}]+\varepsilon *[-\sum_{k=1}^{k}log(p_k)\frac{1}{k}]

=(1-\varepsilon )*H(q,p)+\varepsilon *H(u,p)

这样就得到了标签平滑公式。

4.代码实现

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]
        log_pred = torch.log_softmax(output, dim=-1)
        if self.reduction == 'sum':
            loss = -log_pred.sum()
        else:
            loss = -log_pred.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()


        return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
                                                                                   reduction=self.reduction,
                                                                                   ignore_index=self.ignore_index)

  • 1
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Billie使劲学

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值