Label Smoothing介绍及其代码实现

一、标签平滑(Label Smoothing)介绍

标签平滑(Label Smoothing)的原理其实很简单,它大部分的用处用一句话总结就是:

修改数据集的标签来增加扰动,避免模型的判断过于自信从而陷入过拟合

标签平滑是一种正则化的技术,常常用在分类任务中。它的具体做法就是为数据集的标签增加扰动,它的具体做法如下所示(以K分类任务为例)。

对于K分类来说,假设一个样本 x x x属于第2类,那么实际上用来训练模型(或者说用来计算损失函数)的标签是一个独热编码,具体为[0,0,1,0], 即在位置为2处数值为1(代表属于第2类(从第0类开始计数))。此时标签平滑的具体步骤为:

  1. 定义一个小的扰动常量 ϵ \epsilon ϵ
  2. 将独热编码的标签中的0替换为 ϵ / K \epsilon/K ϵ/K
  3. 将独热编码的标签中的1替换为 1 − e p s i l o n / K 1-epsilon/K 1epsilon/K

由于在现实数据集中,并不是所有标签都是正确标注的,所以直接最大化 log ⁡ p ( y ∣ x ) \log{p}\left(y\mid{x}\right) logp(yx)(即过于自信的把其中一个候选类对应的digit置为1,将其余类的digit置为零), 反而是有害的。这种过于自信的做法不仅仅会使得模型过拟合,而且有可能拟合到错误的例子上去。

一些实验已经证明,标签平滑能够增加模型的泛化能力(Müller et al., 2020)。

二、标签平滑的代码实现

作为一种相对成熟的技术,标签平滑已经有许多开箱即用的实现,在这里我摘取CoinCheung的实现作为示例。

这个版本主要面向pytorch框架,你可以像使用pytorch中的CrossEntropyLoss一样使用它,无需任何改动。

# version 1: use torch.autograd
class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        '''
        Same usage method as nn.CrossEntropyLoss:
            >>> criteria = LabelSmoothSoftmaxCEV1()
            >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
            >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
            >>> loss = criteria(logits, lbs)
        '''
        # overcome ignored label
        logits = logits.float() # use fp32 to avoid nan
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label.eq(self.lb_ignore)
            n_valid = ignore.eq(0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            lb_one_hot = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * lb_one_hot, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()

        return loss

三、参考资料

  1. https://paperswithcode.com/method/label-smoothing
  2. https://arxiv.org/pdf/1906.02629.pdf
  3. https://github.com/CoinCheung/pytorch-loss/blob/master/label_smooth.py
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值