一、标签平滑(Label Smoothing)介绍
标签平滑(Label Smoothing)的原理其实很简单,它大部分的用处用一句话总结就是:
修改数据集的标签来增加扰动,避免模型的判断过于自信从而陷入过拟合
标签平滑是一种正则化的技术,常常用在分类任务中。它的具体做法就是为数据集的标签增加扰动,它的具体做法如下所示(以K分类任务为例)。
对于K分类来说,假设一个样本 x x x属于第2类,那么实际上用来训练模型(或者说用来计算损失函数)的标签是一个独热编码,具体为[0,0,1,0], 即在位置为2处数值为1(代表属于第2类(从第0类开始计数))。此时标签平滑的具体步骤为:
- 定义一个小的扰动常量 ϵ \epsilon ϵ
- 将独热编码的标签中的0替换为 ϵ / K \epsilon/K ϵ/K
- 将独热编码的标签中的1替换为 1 − e p s i l o n / K 1-epsilon/K 1−epsilon/K
由于在现实数据集中,并不是所有标签都是正确标注的,所以直接最大化 log p ( y ∣ x ) \log{p}\left(y\mid{x}\right) logp(y∣x)(即过于自信的把其中一个候选类对应的digit置为1,将其余类的digit置为零), 反而是有害的。这种过于自信的做法不仅仅会使得模型过拟合,而且有可能拟合到错误的例子上去。
一些实验已经证明,标签平滑能够增加模型的泛化能力(Müller et al., 2020)。
二、标签平滑的代码实现
作为一种相对成熟的技术,标签平滑已经有许多开箱即用的实现,在这里我摘取CoinCheung的实现作为示例。
这个版本主要面向pytorch框架,你可以像使用pytorch中的CrossEntropyLoss一样使用它,无需任何改动。
# version 1: use torch.autograd
class LabelSmoothSoftmaxCEV1(nn.Module):
'''
This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients
'''
def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothSoftmaxCEV1, self).__init__()
self.lb_smooth = lb_smooth
self.reduction = reduction
self.lb_ignore = ignore_index
self.log_softmax = nn.LogSoftmax(dim=1)
def forward(self, logits, label):
'''
Same usage method as nn.CrossEntropyLoss:
>>> criteria = LabelSmoothSoftmaxCEV1()
>>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
>>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
>>> loss = criteria(logits, lbs)
'''
# overcome ignored label
logits = logits.float() # use fp32 to avoid nan
with torch.no_grad():
num_classes = logits.size(1)
label = label.clone().detach()
ignore = label.eq(self.lb_ignore)
n_valid = ignore.eq(0).sum()
label[ignore] = 0
lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
lb_one_hot = torch.empty_like(logits).fill_(
lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
logs = self.log_softmax(logits)
loss = -torch.sum(logs * lb_one_hot, dim=1)
loss[ignore] = 0
if self.reduction == 'mean':
loss = loss.sum() / n_valid
if self.reduction == 'sum':
loss = loss.sum()
return loss
三、参考资料
- https://paperswithcode.com/method/label-smoothing
- https://arxiv.org/pdf/1906.02629.pdf
- https://github.com/CoinCheung/pytorch-loss/blob/master/label_smooth.py