label smooth标签平滑的理解

      今天我们来聊一聊label smooth这个tricks,标签平滑已经成为众所周知的机器学习或者说深度学习的正则化技巧。标签平滑——label smooth regularization作为一种简单的正则化技巧,它能提高分类任务中模型的泛化性能和准确率,缓解数据分布不平衡的问题,同时在模型蒸馏中可见它的身影。在最近的2个月3个NLP算法比赛的实战中,label smooth也作为一种炼丹术被我用来提高比赛的成绩。那为啥label smooth有效呢?怎么来解释这个现象呢?那么我们就在这篇博客中一起学习一下label smooth的数学原理以及宏观角度的解释,最后看看最近的论文中有没有更好的label smooth方法。

一、label smooth及其原理和解释

       label smooth是相对于hard label和soft label 而言的,一般的分类任务中我们对label是采用hard label的方式进行one hot编码,而对hard label得到的one hot编码添加一点点噪声。举例如下图来自——如何理解soft target这一做法?

hard label和soft label的优缺点在图中也给出来了,相对来说soft label拥有携带更多的信息,更好的描述数据的类别情况,而hard label丢失了类内和类间的关联,从这个角度来看soft label确实能在一定程度上提高模型的泛化能力,也就是相同数据能提点。

分类问题中, 假设样本 x 的标签为 j ,  \hat{y}为样本对应的预测概率(即softmax的结果)。交叉熵损失如下:

神经网络的输出称为logits,简记为z,经过softmax之后转化为和为1的概率形式,记为\hat{y},真值target记为, 为分类类别的数量。由softmax公式可得:

 当模型的loss为0的时候y=\hat{y},当样本为真样本的时候\hat{y}_{true} =1 , \hat{y}_{false} =0,可以得出: ​

 最终结果是:z_{true}\rightarrow C, z_{false}\rightarrow -\infty 什么意思呢?

神经网络在交叉熵损失函数的时候,当模型loss很低的时候(为0的时候),必然是真样本的logits为常数,假样本的logits为负无穷,一般而言,模型的输出由于采用了激活函数以及有界限定之类的logits不可能为无穷大,采用hard label就不会得到最优的结果——也可以直接说对真样本是其softmax值为1,假样本softmax值为0过于绝对!这就是hard label 不好的原因

 label smooth采用soft label的时候情况就不一样了

abel smooth 学习的编码形式如下图,其中\varepsilon是预定义好的一个超参数,一般取值0.1, 是该分类问题的类别个数:

 经过上述类似的推导——详细推导过程参考文章——简单的label smoothing为什么能够涨点呢,导数等于0的情况下,logit的取值

 可见——使用label-smooth时,假样本的logit不会要求是负无穷。且假样本和真样本的logit值有一定大小误差的情况下,loss就会很小很小,这个对模型效果提升肯定是有益的

二、label smooth的实现

label smooth可以直接使用soft label 然后采用KLDIvLoss计算loss。

import torch
def label_smooth(label, n_class=3,alpha=0.1):
    """
    标签平滑
    :param label: 真实lable
    :param n_class: 类别数目
    :param alpha: 平滑系数
    :return:
    """
    k = alpha / (n_class - 1)
    # temp [batch_size,n_class]
    temp = torch.full((label.shape[0], n_class), k)
    # scatter_.(int dim, Tensor index, Tensor src),这个函数比较难理解——用src张量根据dim和index来修改temp中的元素
    temp = temp.scatter_(1, label.unsqueeze(1), (1-alpha))
    return temp

也可以把soft label以及计算loss的过程统一封装起来,实现一个新的loss function实现如下:

"""
标签平滑
可以把真实标签平滑集成在loss函数里面,然后计算loss
也可以直接在loss函数外面执行标签平滑,然后计算散度loss
"""
import torch.nn as nn
import torch

class LabelSmoothingLoss(nn.Module):
    """
    标签平滑Loss
    """
    def __init__(self, classes, smoothing=0.0, dim=-1):
        """

        :param classes: 类别数目
        :param smoothing: 平滑系数
        :param dim: loss计算平均值的维度
        """
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim
        self.loss = nn.KLDivLoss()

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        #torch.mean(torch.sum(-true_dist * pred, dim=self.dim))就是按照公式来计算损失
        loss = torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
        #采用KLDivLoss来计算
        loss = self.loss(pred,true_dist)
        return loss

forward中的计算也有两种方式 一种是直接采用KLDIvLoss来计算,一种是采用公式一步一步的计算

三、最新的label smooth——在线学习label smooth

《Delving Deep into Label Smoothing》这篇论文就提供了一种在线标签平滑策略方法,使用一种在线学习的方式来生成soft label,相比传统的soft label方法,论文提出的方法声称效提高分类性能和模型的鲁棒性,优于LS、Bootsoft等方法。

原理步骤图

 算法流程

新的损失函数

公式四就是最终的loss函数 

具体实现,放上一份别人实现的代码:

import torch
import torch.nn as nn
from torch import Tensor


class OnlineLabelSmoothing(nn.Module):
    """
    Implements Online Label Smoothing from paper
    https://arxiv.org/pdf/2011.12562.pdf
    使用方法
    from ols import OnlineLabelSmoothing

    criterion = OnlineLabelSmoothing(alpha=..., n_classes=...)
    for epoch in range(...):  # loop over the dataset multiple times
        for i, data in enumerate(...):
            inputs, labels = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch} finished!')
        # Update the soft labels for next epoch
        criterion.next_epoch()

        criterion.eval()
        dev()/test()


    """

    def __init__(self, alpha: float, n_classes: int, smoothing: float = 0.1):
        """
        :param alpha: Term for balancing soft_loss and hard_loss
        :param n_classes: Number of classes of the classification problem
        :param smoothing: Smoothing factor to be used during first epoch in soft_loss
        """
        super(OnlineLabelSmoothing, self).__init__()
        assert 0 <= alpha <= 1, 'Alpha must be in range [0, 1]'
        self.a = alpha
        self.n_classes = n_classes

        # Initialize soft labels with normal LS for first epoch
        self.register_buffer('supervise', torch.zeros(n_classes, n_classes))
        self.supervise.fill_(smoothing / (n_classes - 1))
        self.supervise.fill_diagonal_(1 - smoothing)

        # Update matrix is used to supervise next epoch
        self.register_buffer('update', torch.zeros_like(self.supervise))
        # For normalizing we need a count for each class
        self.register_buffer('idx_count', torch.zeros(n_classes))
        self.hard_loss = nn.CrossEntropyLoss()

    def forward(self, y_h: Tensor, y: Tensor):
        # Calculate the final loss
        soft_loss = self.soft_loss(y_h, y)
        hard_loss = self.hard_loss(y_h, y)
        return self.a * hard_loss + (1 - self.a) * soft_loss

    def soft_loss(self, y_h: Tensor, y: Tensor):
        """
        Calculates the soft loss and calls step
        to update `update`.

        :param y_h: Predicted logits.
        :param y: Ground truth labels.

        :return: Calculates the soft loss based on current supervise matrix.
        """
        y_h = y_h.log_softmax(dim=-1)
        if self.training:
            with torch.no_grad():
                self.step(y_h.exp(), y)
        true_dist = torch.index_select(self.supervise, 1, y).swapaxes(-1, -2)
        return torch.mean(torch.sum(-true_dist * y_h, dim=-1))

    def step(self, y_h: Tensor, y: Tensor) -> None:
        """
        Updates `update` with the probabilities
        of the correct predictions and updates `idx_count` counter for
        later normalization.

        Steps:
            1. Calculate correct classified examples.
            2. Filter `y_h` based on the correct classified.
            3. Add `y_h_f` rows to the `j` (based on y_h_idx) column of `memory`.
            4. Keep count of # samples added for each `y_h_idx` column.
            5. Average memory by dividing column-wise by result of step (4).

        Note on (5): This is done outside this function since we only need to
                     normalize at the end of the epoch.
        """
        # 1. Calculate predicted classes
        y_h_idx = y_h.argmax(dim=-1)
        # 2. Filter only correct
        mask = torch.eq(y_h_idx, y)
        y_h_c = y_h[mask]
        y_h_idx_c = y_h_idx[mask]
        # 3. Add y_h probabilities rows as columns to `memory`
        self.update.index_add_(1, y_h_idx_c, y_h_c.swapaxes(-1, -2))
        # 4. Update `idx_count`
        self.idx_count.index_add_(0, y_h_idx_c, torch.ones_like(y_h_idx_c, dtype=torch.float32))

    def next_epoch(self) -> None:
        """
        This function should be called at the end of the epoch.

        It basically sets the `supervise` matrix to be the `update`
        and re-initializes to zero this last matrix and `idx_count`.
        """
        # 5. Divide memory by `idx_count` to obtain average (column-wise)
        self.idx_count[torch.eq(self.idx_count, 0)] = 1  # Avoid 0 denominator
        # Normalize by taking the average
        self.update /= self.idx_count
        self.idx_count.zero_()
        self.supervise = self.update
        self.update = self.update.clone().zero_()

实际效果待验证!

参考文章

简单的label smoothing为什么能够涨点呢

如何理解soft target这一做法?

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值