SAM解析:Sharpness-Aware Minimization for Efficiently Improving Generalization

本文介绍了Sharpness-Aware Minimization (SAM) 方法,该方法旨在改善深度神经网络的泛化能力。通过对目标函数的解析,展示了如何通过最大化训练损失在参数附近的增益来考虑模型的锐度,进而调整优化目标。SAM通过近似内层最大化问题并使用一阶泰勒展开来计算梯度,提出了一个新的优化目标,这有助于找到更稳定的权重。实验表明,这种方法可以有效地提高模型的泛化性能。
摘要由CSDN通过智能技术生成

论文:Sharpness-Aware Minimization for Efficiently Improving Generalization( ICLR 2021)

一、理论

综合了另一篇论文:ASAM: Adaptive Sharpness-Aware Minimization
for Scale-Invariant Learning of Deep Neural Networks 对理论部分这边的解释,同时这篇论文自己也对SAM做出了改进。

1.要解决什么问题?

(1)目标函数

L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)是总体损失值,而训练集 L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}) LS(w)是样本损失值,即训练集S的损失值(S i.i.d. from distribution D \mathscr{D} D),我们使用 L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}) LS(w)来估计 L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)

根据PAC-Bayesian generalization bound定理的非正式形式(严格的定理以及其证明在论文的附录 A.1):

对于任何 ρ > 0 \rho>0 ρ>0 ρ \rho ρ为领域大小),从分布来看,生成的训练集大概率满足:
L D ( w ) ≤ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) L_{\mathscr{D}}(\boldsymbol{w}) \leq \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) LD(w)ϵ2ρmaxLS(w+ϵ)+h(w22/ρ2) 其中, h : R + → R + h: \mathbb{R}_{+} \rightarrow \mathbb{R}_{+} h:R+R+是单调递增函数。附录 A.1结尾给出了具体这个 h h h函数是什么。
L D ( w ) ≤ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + + k log ⁡ ( 1 + ∥ w ∥ 2 2 ρ 2 ( 1 + log ⁡ ( n ) k ) 2 ) + 4 log ⁡ n δ + 8 log ⁡ ( 6 n + 3 k ) n − 1 \begin{aligned} L_{\mathscr{D}}(\boldsymbol{w}) & \leq \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+\\ &+\sqrt{\frac{k \log \left(1+\frac{\|\boldsymbol{w}\|_{2}^{2}}{\rho^{2}}\left(1+\sqrt{\frac{\log (n)}{k}}\right)^{2}\right)+4 \log \frac{n}{\delta}+8 \log (6 n+3 k)}{n-1}} \end{aligned} LD(w)ϵ2ρmaxLS(w+ϵ)++n1klog(1+ρ2w22(1+klog(n) )2)+4logδn+8log(6n+3k)
那么也就是说,为了得到损失函数 L Q ( w ) L_{\mathscr{Q}}(\boldsymbol{w}) LQ(w)的极小解,现在要求:不等式右边式子,即 max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) + h ( ∥ w ∥ 2 2 / ρ 2 ) \max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) maxϵ2ρLS(w+ϵ)+h(w22/ρ2)的最小值,也就是一个极小-极大化minimax问题。

(2)引出锐度sharpness概念

这里把不等式的右边展开为:
[ max ⁡ ∥ ϵ ∥ 2 ≤ ρ L S ( w + ϵ ) − L S ( w ) ] + L S ( w ) + h ( ∥ w ∥ 2 2 / ρ 2 ) \left[\max _{\|\epsilon\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})-L_{\mathcal{S}}(\boldsymbol{w})\right]+L_{\mathcal{S}}(\boldsymbol{w})+h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) [ϵ2ρmaxLS(w+ϵ)LS(w)]+LS(w)+h(w22/ρ2)
这个中括号里的就是 L S L_{\mathcal{S}} LS锐度sharpness,可以衡量从 w \boldsymbol{w} w移动到相临近的参数值,training loss增加的速度。这个式子就变成sharpness + training loss+ w w w的正则项了。

(3)进一步对目标函数进行变换

考虑到具体的函数 h h h受证明细节的影响严重,把 h ( ∥ w ∥ 2 2 / ρ 2 ) h\left(\|\boldsymbol{w}\|_{2}^{2} / \rho^{2}\right) h(w22/ρ2)替换成 λ ∥ w ∥ 2 2 \lambda\|w\|_{2}^{2} λw22,这样就产生了一个标准的L2正则项了, λ ∥ w ∥ 2 2 \lambda\|w\|_{2}^{2} λw22 w w w的正则项, λ \lambda λ即为weight decay。

综合以上,现在的目标函数可以变成(公式0):
min ⁡ w L S S A M ( w ) + λ ∥ w ∥ 2 2  where  L S S A M ( w ) ≜ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) \min _{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w})+\lambda\|\boldsymbol{w}\|_{2}^{2} \quad \text { where } \quad L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \triangleq \max _{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} L_{S}(\boldsymbol{w}+\boldsymbol{\epsilon}) wminLSSAM(w)+λw22 where LSSAM(w)ϵpρmaxLS(w+ϵ)把它叫做Sharpness-Aware Minimization (SAM) problem。这里, ρ ≥ 0 \rho \geq 0 ρ0是一个超参数,以及 p ∈ [ 1 , ∞ ] p \in[1, \infty] p[1,],论文说 p = 2 p =2 p=2通常是最优的(附录C.5)。

2.怎么解决?

为了最小化 L S S A M ( w ) L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) LSSAM(w),论文提出可以通过对inner maximization求微分来得出 ∇ w L S S A M ( w ) \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) wLSSAM(w)的近似值。

可是这个the inner maximization是什么东西呢?

the inner problem,指的是在固定损失函数后,如何在训练集上更新模型超参数使得损失函数值最小,得到当前最优的模型参数的问题。对应的数学表达式就是[超参数] = arg min[损失函数]。

inner maximization其实就是上面的使得损失函数值最小改成最大啦,用数学语言表示就是(公式1):
ϵ ∗ ( w ) ≜ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) \boldsymbol{\epsilon}^{*}(\boldsymbol{w}) \triangleq \underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) ϵ(w)ϵpρargmaxLS(w+ϵ)意思是:在 w w w ϵ \epsilon ϵ邻域 [ a − ϵ , a + ϵ ] [\mathrm{a}-\epsilon, \mathrm{a}+\epsilon] [aϵ,a+ϵ]中,使得目标函数 L S ( w + ϵ ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) LS(w+ϵ)最大的参数 ϵ \epsilon ϵ

然后,y用在 x = w + ϵ x=w+\epsilon x=w+ϵ处, ϵ → 0 \epsilon \to 0 ϵ0一阶泰勒展开式去逼近一个the inner maximization的目标函数:

a. x = x 0 x=x_{0} x=x0处的一阶泰勒展开式:
f ( x ) = f ( x 0 ) + f ′ ( x 0 ) ( x − x 0 ) + o ( x − x 0 ) f(x)=f\left(x_{0}\right)+f^{\prime}\left(x_{0}\right)\left(x-x_{0}\right)+o\left(x-x_{0}\right) f(x)=f(x0)+f(x0)(xx0)+o(xx0)

b. x = w + ϵ x=w+\epsilon x=w+ϵ处, ϵ → 0 \epsilon \to 0 ϵ0的一阶泰勒展开式:
L S ( w + ϵ ) = L S ( w ) + ϵ T ∇ w L S ( w ) + o ( ϵ ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})=L_{\mathcal{S}}(\boldsymbol{w})+ \epsilon^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})+o\left(\epsilon\right) LS(w+ϵ)=LS(w)+ϵTwLS(w)+o(ϵ) L S ( w + ϵ ) ≈ L S ( w ) + ϵ T ∇ w L S ( w ) L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \approx L_{\mathcal{S}}(\boldsymbol{w})+ \epsilon^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) LS(w+ϵ)LS(w)+ϵTwLS(w)

c. 又因为 arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w ) = 0 \underset{\|\epsilon\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}) =0 ϵpρargmaxLS(w)=0
把b、c公式代入公式1,得到:
ϵ ∗ ( w ) ≜ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w + ϵ ) ≈ arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ L S ( w ) + ϵ T ∇ w L S ( w ) = arg ⁡ max ⁡ ∥ ϵ ∥ p ≤ ρ ϵ T ∇ w L S ( w ) \begin{aligned} \boldsymbol{\epsilon}^{*}(\boldsymbol{w}) &\triangleq \underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \\ & \approx \underset{\|\epsilon\|_{p} \leq \rho}{\arg \max } L_{\mathcal{S}}(\boldsymbol{w})+\boldsymbol{\epsilon}^{T} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) \\ &=\underset{\|\boldsymbol{\epsilon}\|_{p} \leq \rho}{\arg \max } \boldsymbol{\epsilon}^{T} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) \end{aligned} ϵ(w)ϵpρargmaxLS(w+ϵ)ϵpρargmaxLS(w)+ϵTwLS(w)=ϵpρargmaxϵTwLS(w)

现在,用 ϵ ^ ( w ) \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) ϵ^(w)表示 ϵ ( w ) \boldsymbol{\epsilon}(\boldsymbol{w}) ϵ(w)的近似值, ϵ ^ ( w ) \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) ϵ^(w)可以用经典对偶范数问题(dual
norm problem)的解法来得出:
ϵ ^ ( w ) = ρ sign ⁡ ( ∇ w L S ( w ) ) ∣ ∇ w L S ( w ) ∣ q − 1 / ( ∥ ∇ w L S ( w ) ∥ q q ) 1 / p \hat{\boldsymbol{\epsilon}}(\boldsymbol{w})=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right)\left|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|^{q-1} /\left(\left\|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right\|_{q}^{q}\right)^{1 / p} ϵ^(w)=ρsign(wLS(w))wLS(w)q1/(wLS(w)qq)1/p

其中, 1 / p + 1 / q = 1 1 / p+1 / q=1 1/p+1/q=1

之前说过论文提到 p = 2 p =2 p=2通常效果最优,这里就把 p = 2 p =2 p=2代入到上式中, 1 / p = 1 / 2 1/p = 1/2 1/p=1/2, q − 1 = 1 q-1=1 q1=1,则得到公式2
ϵ ^ ( w ) = ρ sign ⁡ ( ∇ w L S ( w ) ) ∣ ∇ w L S ( w ) ∣ ∥ ∇ w L S ( w ) ∥ 2 = ρ ∇ L S ( w ) ∥ ∇ L S ( w ) ∥ 2 \begin{aligned} \hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) &=\rho \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right) \frac{|\nabla_{\boldsymbol{w}} L_{S}\left(\mathbf{\boldsymbol{w}}\right)|}{\left\|\nabla_{\boldsymbol{w}} L_{S}\left(\mathbf{\boldsymbol{w}}\right)\right\|_{2}} \\ & = \rho \frac{\nabla L_{S}\left(\mathbf{w}\right)}{\left\|\nabla L_{S}\left(\mathbf{w}\right)\right\|_{2}} \end{aligned} ϵ^(w)=ρsign(wLS(w))wLS(w)2wLS(w)=ρLS(w)2LS(w)
把上式代入到公式0,得到:
∇ w L S S A M ( w ) ≈ ∇ w L S ( w + ϵ ^ ( w ) ) = d ( w + ϵ ^ ( w ) ) d w ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) = ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) + d ϵ ^ ( w ) d w ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) \begin{aligned} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) & \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})) \\ & =\left.\frac{d(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(w)} \\ &=\left.\nabla_{w} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})}+\left.\frac{d \hat{\epsilon}(\boldsymbol{w})}{d \boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|_{\boldsymbol{w}+\hat{\epsilon}(\boldsymbol{w})} \end{aligned} wLSSAM(w)wLS(w+ϵ^(w))=dwd(w+ϵ^(w))wLS(w)w+ϵ^(w)=wLS(w)w+ϵ^(w)+dwdϵ^(w)wLS(w)w+ϵ^(w)
上面第二行用到了复合微分:
d f ( g ( x ) ) d x = d g ( x ) d x d f ( x ) ∣ g ( x ) \begin{aligned} &\frac {d f(g(x))}{dx} =\left.\frac{d g(x)}{d x} d f(x)\right|_{g(x)} \end{aligned} dxdf(g(x))=dxdg(x)df(x)g(x)
公式3
∇ w L S S A M ( w ) ≈ ∇ w L S ( w ) ∣ w + ϵ ^ ( w ) \left.\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(w)\right|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})} wLSSAM(w)wLS(w)w+ϵ^(w)
如Section 3中的结果所示,如果没有二阶项,这个近似值会产生一个有效的算法。在论文附录C.4中具体研究了这种影响,在开始的实验中,加入二阶项会令人惊讶地降低性能,进一步研究这些条款的影响应该是未来工作的优先事项。

我们通过将标准的数值优化器(如随机梯度下降SGD)应用于SAM目标函数 L S S A M ( w ) L_{\mathcal{S}}^{S A M}(\boldsymbol{w}) LSSAM(w),其中使用公式3计算目标函数梯度,从而获得最终的SAM算法。

论文的Algorithm 1给出了SAM算法的伪代码,它使用SGD作为基本优化器。右边的Figure 2示意性地说明了SAM参数的单次迭代更新。

在这里插入图片描述

  • 14
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值