深度学习:标签平滑(Label Smoothing Regularization)

1.标签平滑的作用—防止过拟合

在进行多分类时,很多时候采用one-hot标签进行计算交叉熵损失,而单纯的交叉熵损失时,只考虑到了正确标签的位置的损失,而忽略了错误标签位置的损失。这样导致模型可能会在训练集上拟合的非常好,但由于其错误标签位置的损失没有计算,导致预测的时候,预测错误的概率比较大,也就是常说的过拟合。
标签平滑可以在一定程度上防止过拟合。

2. 传统的交叉熵损失计算

Step1: softmax多分类
P i = e z i ∑ i = 1 n e z i P_i = { e^{z_i} \over {\sum_{i=1}^{n} e^{z_i}} } Pi=i=1neziezi
其中, p i p_i pi为当前样本属于类别 i i i的概率, z i z_i zi 指当前样本的对应类别 i i i l o g i t logit logit, n表示样本的总列别数。
Step2: 交叉熵损失计算公式:
c r o s s L o s s = − 1 M ∑ m = 1 M ∑ i = 1 n y i l o g p i crossLoss = - {1 \over M} {\sum_{m=1}^M {\sum_{i=1}^n}} y_ilog{p_i} crossLoss=M1m=1Mi=1nyilogpi
其中, M M M表示样本综述。
实例:
假设一批样本,样本类别的总数n=5, 其中一个样本的one-hot标签为 [ 0 , 0 , 0 , 1 , 0 ] [0,0,0,1,0] [0,0,0,1,0],假设通过模型(如全连接等)的 l o g i t logit logit进行softmax后的概率矩阵 p p p为:
p = [ 0.1 , 0.1 , 0.1 , 0.36 , 0.34 ] p = [0.1,0.1,0.1, 0.36, 0.34] p=[0.1,0.1,0.1,0.36,0.34]
将其带入到上面的公式,即可计算出单个样本的loss为:
l o s s = − ( 0 ∗ l o g 0.1 + 0 ∗ l o g 0.1 + 0 ∗ l o g 0.1 + 1 ∗ l o g 0.36 + 0 ∗ l o g 0.34 ) = − l o g 0.36 = 1.47 loss = -(0*log0.1+0*log0.1+0*log0.1+1*log0.36+0*log0.34) = -log0.36=1.47 loss=(0log0.1+0log0.1+0log0.1+1log0.36+0log0.34)=log0.36=1.47
这种传统计算交叉熵损失只考虑了正确标签位置的损失,而没有考虑错误标签的损失。下面让我们看看带有标签平滑的交叉熵损失是怎样计算的吧。

3.带有标签平滑的交叉熵损失的计算

同样是上面的例子:一批样本,样本类别的总数n=5, 其中一个样本的one-hot标签为 [ 0 , 0 , 0 , 1 , 0 ] [0,0,0,1,0] [0,0,0,1,0],假设通过模型(如全连接等)的 l o g i t logit logit进行softmax后的概率矩阵 p p p为:
p = [ 0.1 , 0.1 , 0.1 , 0.36 , 0.34 ] p = [0.1,0.1,0.1, 0.36, 0.34] p=[0.1,0.1,0.1,0.36,0.34]
设:标签的平滑因子 ϵ = 0.1 \epsilon=0.1 ϵ=0.1,平滑的计算步骤如下:
y 1 = ( 1 − ϵ ) ∗ [ 0 , 0 , 0 , 1 , 0 ] = [ 0 , 0 , 0 , 0.9 , 0 ] y1 = (1-\epsilon)*[0,0,0,1,0] = [0,0,0,0.9,0] y1=(1ϵ)[0,0,0,1,0]=[0,0,0,0.9,0]
y 2 = ϵ ∗ [ 1 , 1 , 1 , 1 , 1 ] / 5 = [ 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ] / 5 = [ 0.02 , 0.02 , 0.02 , 0.02 , 0.02 ] y2 = \epsilon*[1,1,1,1,1] / 5= [0.1,0.1,0.1,0.1,0.1]/5 = [0.02, 0.02, 0.02, 0.02, 0.02] y2=ϵ[1,1,1,1,1]/5=[0.1,0.1,0.1,0.1,0.1]/5=[0.02,0.02,0.02,0.02,0.02]
y = y 1 + y 2 = [ 0.02 , 0.02 , 0.02 , 0.92 , 0.02 ] y = y1+y2 = [0.02,0.02,0.02,0.92, 0.02] y=y1+y2=[0.02,0.02,0.02,0.92,0.02]
y y y即是平滑后的新标签,然后按照传统的交叉熵损失计算步骤即可,如:
l o s s = − y ∗ l o g p = − [ 0.02 , 0.02 , 0.02 , 0.92 , 0.02 ] ∗ l o g ( [ 0.1 , 0.1 , 0.1 , 0.36 , 0.34 ] ) = 2.63 loss=-y*logp = -[0.02,0.02,0.02,0.92, 0.02] *log([0.1,0.1,0.1,0.36,0.34])=2.63 loss=ylogp=[0.02,0.02,0.02,0.92,0.02]log([0.1,0.1,0.1,0.36,0.34])=2.63

4.标签平滑与传统的交叉熵损失的比较与分析

有上面实例可以看出,带有标签平滑的损失要比传统交叉熵损失要更大。换言之,带有标签平滑的损失要想下降到传统交叉熵损失的程度,就要学习的更好,迫使模型往正确分类的方向走。

5. 标签平滑的应用场景

只要用到的是交叉熵损失(cross loss),都可以采取标签平滑处理。

6.pytorch的实现与使用

import torch
import torch.nn as nn
import torch.nn.functional as F


class CELossWithLabelSmoothing(nn.Module):
    ''' Cross Entropy Loss with label smoothing '''
    def __init__(self, label_smooth=0.1, class_num=3755):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num

    def forward(self, pred, target):
        '''
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler [N]
        '''
        eps = 1e-12

        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)  # softmax + log
            target = F.one_hot(target, self.class_num)  # 转换成one-hot

            # label smoothing
            # 实现 1
            # target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num
            # 实现 2
            # implement 2
            target = torch.clamp(target.float(), min=self.label_smooth / (self.class_num - 1),
                                 max=1.0 - self.label_smooth)
            loss = -1 * torch.sum(target * logprobs, 1)

        else:
            # standard cross entropy loss
            loss = -1. * pred.gather(1, target.unsqueeze(-1)) + torch.log(torch.exp(pred + eps).sum(dim=1))

        return loss.mean()


if __name__ == '__main__':
    loss2 = CELossWithLabelSmoothing(label_smooth=0.2, class_num=3)
    x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
    y = torch.tensor([1, 2])
    print(loss2(x, y))
  • 16
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 17
    评论
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值