SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION-论文精读
锐度感知最小化 (SAM)paper-link 是一种简单而有效的技术,用于提高深度学习模型在看不见的数据样本上的泛化能力。
摘要
在当今过度参数化的模型中,训练损失的值很少能保证模型的泛化能力。事实上,仅优化训练损失值(如通常所做的那样)很容易导致次优的模型质量。
受先前将损失景观的几何形状与泛化能力相关联的工作的启发,我们引入了一种新颖而有效的方法,该方法通过同时最小化损失值和损失锐度来提高模型泛化能力。特别是,我们的方法,锐度感知最小化 (SAM),寻求位于具有均匀低损失的邻域中的参数;这种公式导致了一个可以在其上有效地进行梯度下降的最小-最大优化问题。我们展示了实验结果,表明 SAM 在各种基准数据集(例如 CIFAR-{10, 100}、ImageNet、微调任务)和模型上提高了模型的泛化能力,为其中一些模型带来了新的最先进性能。此外,我们发现 SAM 本身就具有与专门针对噪声标签学习而设计的最先进程序相当的对标签噪声的鲁棒性。
代码: https://github.com/google-research/sam
关键词:SAM (锐度感知最小化),泛化能力,损失锐度,梯度下降,损失函数
核心思想
- 锐度定义:
SAM 将锐度定义为在给定参数值附近,损失值随参数变化的速率。通过最小化 SAM 损失,SAM 寻找损失函数在参数空间中具有低锐度的最小值。 - 算法步骤:
SAM 算法可以与任何梯度下降优化器结合使用,例如 SGD 或 Adam。
- 在每次迭代中,SAM 算法首先计算当前参数的梯度。
- 然后,它使用梯度的方向和大小来生成一个“对抗性扰动”,该扰动将沿着梯度的方向将参数移动到损失增加最多的方向。
SAM 算法计算参数加上对抗性扰动后的损失值,并将其作为 SAM 损失。 - 最后,SAM 算法使用梯度下降优化器来最小化 SAM 损失,从而更新参数。
- 锐度感知梯度计算:
SAM 算法使用一阶泰勒展开来近似 SAM 损失的梯度。
这使得 SAM 算法可以有效地计算梯度,而无需显式地计算损失函数的 Hessian 矩阵。 - 优势:
SAM 算法简单易实现,可以与任何梯度下降优化器结合使用;可以显著提高模型的泛化能力,尤其是在过度参数化的模型中;对标签噪声具有鲁棒性,可以提高模型在噪声数据上的性能。 - 参数:
SAM 算法有一个重要的参数 ρ,它控制着参数空间中考虑的邻域大小,ρ 的选择会影响 SAM 算法的性能,需要根据具体任务进行调整。
代码实现
def sam_loss(loss_fn, model, inputs, targets, rho):
"""
计算 SAM 损失函数的近似值。
"""
# 计算当前参数的梯度
model.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()
# 生成对抗性扰动
grad_norm = torch.norm(model.parameters(), p=2)
epsilon = rho * torch.sign(model.parameters()) * torch.norm(model.parameters(), p=2).reciprocal()
# 计算参数加上对抗性扰动后的损失值
perturbed_model = copy.deepcopy(model)
perturbed_model.load_state_dict(model.state_dict())
for name, param in perturbed_model.named_parameters():
param.data += epsilon[name]
perturbed_loss = loss_fn(perturbed_model(inputs), targets)
# 返回 SAM 损失的近似值
return max(perturbed_loss, loss)