标签平滑(label-smoothing)主要用于防止过拟合,增强模型的泛化能力。在one-hot的基础上,添加一个平滑系数 $\varepsilon$,使得最大预测与其它类别平均值之间差距的经验分布更加平滑。
PyTorch代码实现
import torch
def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """Build label-smoothed target distributions from one-hot rows.

    For every row of ``true_labels`` the arg-max along dim 1 is taken as the
    target class. That column is set to ``1 - smoothing`` and every other
    column to ``smoothing / (classes - 1)``.

    If ``smoothing == 0`` the result is a plain one-hot encoding;
    if ``0 < smoothing < 1`` it is the smoothed encoding.

    Args:
        true_labels: Tensor indexed as ``(batch, class)``; the arg-max of each
            row along dim 1 selects the target class.
        classes: Number of columns of the returned tensor.
        smoothing: Smoothing coefficient, must satisfy ``0 <= smoothing < 1``.

    Returns:
        A tensor of shape ``(true_labels.size(0), classes)`` allocated on the
        same device as ``true_labels``.

    Raises:
        AssertionError: If ``smoothing`` is outside ``[0, 1)``.
        ZeroDivisionError: If ``classes == 1`` (no other class to share the
            smoothing mass with).
    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        # torch.empty leaves memory uninitialised; fill_ sets the off-target mass.
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        _, index = torch.max(true_labels, 1)
        # torch.max already yields int64 indices on the input's device, so they
        # are passed to scatter_ directly. Wrapping them in torch.LongTensor(...)
        # would force a CPU tensor and break for inputs on another device.
        true_dist.scatter_(1, index.unsqueeze(1), confidence)
    return true_dist
# Build a 2x5 one-hot batch: sample 0 -> class 1, sample 1 -> class 3.
true_labels = torch.zeros(2, 5)
true_labels[0, 1] = 1
true_labels[1, 3] = 1
print('标签平滑前:\n', true_labels)

# Smooth the targets with a coefficient of 0.1 and show the result.
true_dist = smooth_one_hot(true_labels, smoothing=0.1, classes=5)
print('标签平滑后:\n', true_dist)
'''
Loss = CrossEntropyLoss(NonSparse=True, ...)
. . .
data = ...
labels = ...
outputs = model(data)
smooth_label = smooth_one_hot(labels, ...)
loss = Loss(outputs, smooth_label)
...
'''
打印出来结果:
标签平滑前:
tensor([[0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0.]])
标签平滑后:
tensor([[0.0250, 0.9000, 0.0250, 0.0250, 0.0250],
[0.0250, 0.0250, 0.0250, 0.9000, 0.0250]])