什么是label smoothing?
标签平滑(Label smoothing
),像L1
、L2
和dropout
一样,是机器学习领域的一种正则化方法,通常用于分类问题,目的是防止模型在训练时过于自信地预测标签,改善泛化能力差的问题。
使用label smoothing目的
label smoothing
常用于分类任务,防止模型在训练中过拟合,提高模型的泛化能力。
使用label smoothing
class LabelSmoothingCrossEntropy(nn.Module):
"""
Cross Entropy loss with label smoothing.
"""
def __init__(self, smoothing=0.1):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothingCrossEntropy, self).__init__()
assert 0.0 < smoothing < 1.0
self.smoothing = smoothing
self.confidence = 1. - smoothing
def forward(self, x, target):
"""
写法1
"""
# logprobs = F.log_softmax(x, dim=-1)
# nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
# nll_loss = nll_loss.squeeze(1) # 得到交叉熵损失
# # 注意这里要结合公式来理解,同时留意预测正确的那个类,也有a/K,其中a为平滑因子,K为类别数
# smooth_loss = -logprobs.mean(dim=1)
# loss = self.confidence * nll_loss + self.smoothing * smooth_loss
"""
写法2
"""
y_hat = torch.softmax(x, dim=1)
# 这里cross_loss和nll_loss等价
cross_loss = self.cross_entropy(y_hat, target)
smooth_loss = -torch.log(y_hat).mean(dim=1)
# smooth_loss也可以用下面的方法计算,注意loga + logb = log(ab)
# smooth_loss = -torch.log(torch.prod(y_hat, dim=1)) / y_hat.shape[1]
loss = self.confidence * cross_loss + self.smoothing * smooth_loss
return loss.mean()
def cross_entropy(self, y_hat, y):
return - torch.log(y_hat[range(len(y_hat)), y])
然后你可能刚开始使用的损失函数是:
lossfunc = nn.CrossEntropyLoss()
只需要改成:
lossfunc = LabelSmoothingCrossEntropy(smoothing=0.1)
即可在你的代码种使用标签平滑。