Sharpness-Aware Minimization for Efficiently Improving Generalization

本文介绍了一种名为Sharpness-aware最小化(SAM)的方法,该方法通过优化训练过程来改善模型的泛化能力。SAM能够在不增加额外计算成本的情况下找到更加平滑的局部最小值,从而提高模型在未知数据上的表现。实验表明,这种方法不仅能提高准确率,还能增强模型对标签噪声的鲁棒性。
摘要由CSDN通过智能技术生成

文章目录

Foret P., Kleiner A., Mobahi H., Neyshabur B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations.

在训练的时候对权重加扰动能增强泛化性.

主要内容

如上图所示, 一般的训练方法虽然能够收敛到一个不错的局部最优点, 但是往往这个局部最优点附近是非常不光滑的, 即对权重 w w w添加微小的扰动 w + ϵ w+\epsilon w+ϵ 可能就会导致不好的结果, 作者认为这与模型的泛化性有很大关系(实际上已有别的文章提出这一观点).

作者给出如下的理论分析:

在满足一定条件下有
L D ( w ) ≤ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) . L_{\mathscr{D}} (w) \le \max_{\|\epsilon \|_2 \le \rho} L_{\mathcal{S}} (w + \epsilon) + h(\|w\|_2^2/\rho^2). LD(w)ϵ2ρmaxLS(w+ϵ)+h(w22/ρ2).
其中 h h h是一个严格单调递增函数, L S L_{\mathcal{S}} LS是在训练集 S \mathcal{S} S上的损失,
L D ( w ) = E ( x , y ) ∼ D [ l ( x , y ; w ) ] . L_{\mathscr{D}}(w) = \mathbb{E}_{(x, y) \sim \mathscr{D}} [l(x, y;w)]. LD(w)=E(x,y)D[l(x,y;w)].

如果把 h ( ∥ w ∥ 2 2 / ρ 2 ) h(\|w\|_2^2/\rho^2) h(w22/ρ2)看成 λ ∥ w ∥ 2 2 \lambda \|w\|_2^2 λw22(即常用的weight decay), 我们的目标函数可以认为是
min ⁡ w L S S A M ( w ) + λ ∥ w ∥ 2 2 , \min_w L_{\mathcal{S}}^{SAM} (w) + \lambda \|w\|_2^2, wminLSSAM(w)+λw22,

L S S A M ( w ) : = max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) , L_{\mathcal{S}}^{SAM}(w) := \max_{\|\epsilon \|_p \le \rho} L_{\mathcal{S}} (w + \epsilon), LSSAM(w):=ϵpρmaxLS(w+ϵ),
注: 这里 ∥ ⋅ ∥ p \|\cdot \|_p p而并不仅限于 ∥ ⋅ ∥ 2 \|\cdot \|_2 2.

采用近似的方法求解上面的问题(就和对抗样本一样):
ϵ ∗ ( w ) : = arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) ≈ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w ) + ϵ T ∇ w L S ( w ) = arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ ϵ T ∇ w L S ( w ) . \epsilon^* (w) := \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} L_{\mathcal{S}}(w + \epsilon) \approx \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} L_{\mathcal{S}}(w) + \epsilon^T \nabla_w L_{\mathcal{S}}(w) = \mathop{\arg \max} \limits_{\|\epsilon\|_p\le \rho} \epsilon^T \nabla_w L_{\mathcal{S}}(w). ϵ(w):=ϵpρargmaxLS(w+ϵ)ϵpρargmaxLS(w)+ϵTwLS(w)=ϵpρargmaxϵTwLS(w).

就是一个对偶范数的问题.

虽然 ϵ ∗ ( w ) \epsilon^*(w) ϵ(w)实际上是和 w w w有关的, 但是在实际中只是当初普通的量带入, 这样就不用计算二阶导数了, 即
∇ w L S S A M ( w ) ≈ ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) . \nabla_w L_{\mathcal{S}}^{SAM}(w) \approx \nabla_w L_{\mathcal{S}}(w) |_{w + \hat{\epsilon}(w)}. wLSSAM(w)wLS(w)w+ϵ^(w).

实验结果非常好, 不仅能够提高普通的正确率, 在标签受到污染的情况下也能有很好的鲁棒性.

代码

原文代码

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值