【论文笔记_优化_2021】SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION

有效提高泛化能力的清晰度感知最小化
请添加图片描述

摘要

在当今高度过度参数化的模型中,训练损失的值几乎不能保证模型的泛化能力。事实上,通常只优化训练损失值,很容易导致次优的模型质量。受先前工作的启发,我们引入了一种新颖、有效的方法来同时最小化损失值和损失锐度。特别地,我们的过程,清晰度感知最小化(SAM),寻找位于具有一致低损失的邻域中的参数;这个公式导致最小最大优化问题,在这个问题上可以有效地执行梯度下降。我们提供的实证结果表明,SAM提高了各种基准数据集(例如,CIFAR-{10,100}、ImageNet、微调任务)和模型的模型泛化能力,为几个数据集和模型带来了新的最先进的性能。此外,我们发现SAM本身提供了对标签噪声的鲁棒性,与专门针对具有噪声标签的学习的最先进的过程所提供的鲁棒性相当。我们在https://github . com/Google-research/Sam上开源代码。

1.介绍

现代机器学习在广泛的任务上实现更好的性能的成功在很大程度上依赖于越来越重的过度参数化,以及开发越来越有效的训练算法,这些算法能够找到能够很好地概括的参数。事实上,许多现代神经网络可以很容易地记住训练数据,并有能力很容易过拟合(张等人,2016)。这种严重的过度参数化目前需要在多个领域实现最先进的结果(Tan & Le,2019;科列斯尼科夫等人,2020;黄等,2018)。反过来,使用确保实际选择的参数实际上概括超出训练集的程序来训练这种模型是必要的。

不幸的是,简单地最小化训练集上常用的损失函数(例如,交叉熵)通常不足以实现令人满意的泛化。当今模型的训练损失景观通常是复杂的和非凸的,具有多个局部和全局最小值,并且具有不同的全局最小值,产生具有不同泛化能力的模型(Shirish Keskar等人,2016)。因此,从许多可用的优化器(例如,随机梯度下降(内斯特罗夫,1983),亚当(金马&巴,2014),RMSProp(辛顿等人),和其他(杜奇等人,2011;多扎特,2016;Martens & Grosse,2015))已成为一个重要的设计选择,尽管对其与模型泛化的关系的理解仍处于萌芽状态(Shirish Keskar等人,2016;威尔逊等人,2017;Shirish Keskar & Socher,2017;Agarwal等人,2020;Jacot等人,2018)。与此相关,已经提出了一整套修改训练过程的方法,包括dropout(Srivastava等人,2014年),批量归一化(Ioffe & Szegedy,2015)、随机深度(黄等,2016)、数据扩充(Cubuk等,2018)、混合样本扩充(张等,2017;哈里斯等人,2020)。

请添加图片描述
图1:(左)通过切换到SAM降低了错误率。每个点都是不同的数据集/模型/数据扩充。(中)用SGD训练的ResNet收敛到的锐最小值。(右)与SAM一起训练的ResNet收敛到的宽最小值。

损失景观的几何形状(特别是极小值的平坦性)和概化之间的联系已经从理论和经验的角度进行了广泛的研究(Shirish Keskar等人,2016;Dziugaite & Roy,2017;蒋等,2019)。虽然这种联系有希望实现产生更好的泛化的模型训练的新方法,但专门寻找更平坦的最小值并进一步有效地提高一系列最新模型的泛化的实用有效算法迄今仍难以捉摸(例如,参见(Chaudhari等人,2016;伊兹迈洛夫等,2018);我们在第5节中对之前的工作进行了更详细的讨论。

我们在此提出了一种新的高效、可扩展且有效的方法来提高模型泛化能力,该方法直接利用损失景观的几何形状及其与泛化的联系,并且是对现有技术的有力补充。特别是,我们做出了以下贡献:
1.我们引入了清晰度感知最小化(SAM),这是一种通过同时最小化损失值和损失清晰度来提高模型泛化能力的新方法。SAM通过寻找位于具有一致低损失值的邻域中的参数(而不是如图1的中间和右侧图像所示的仅自身具有低损失值的参数)来起作用,并且可以有效和容易地实现。

2.我们通过一项严格的实证研究表明,使用SAM提高了一系列广泛研究的计算机视觉任务(例如,CIFAR-{10,100}、ImageNet、微调任务)和模型的模型泛化能力,如图1左侧图所示。例如,对于许多已经被深入研究过的任务,例如ImageNet、CIFAR-{10,100}、SVHN、Fashion-MNIST,以及图像分类微调任务的标准集(例如,花卉、斯坦福汽车、牛津宠物等),应用SAM会得到最佳性能。

3.我们表明,SAM还提供了对标签噪声的鲁棒性,与专门针对带噪声标签的学习的最新程序提供的鲁棒性相当。

4.通过SAM提供的镜头,我们提出了一个很有前途的新的锐度概念,我们称之为m-锐度,从而进一步阐明了锐度损失和泛化之间的联系。

下面的第2节推导了SAM过程,并详细介绍了由此产生的算法。第3节对SAM进行了经验评估,第4节通过SAM的镜头进一步分析了损失锐度和泛化之间的联系。最后,我们在第5节和第6节分别总结了相关的工作并讨论了结论和未来的工作。

2.清晰度(锐度)感知最小化(SAM)

请添加图片描述
利用LS(w)作为LD(w)的估计值推动了选择参数w的标准方法,即通过使用优化程序(如SGD或Adam)求解最小LS(w)(可能与w上的正则化因子结合)来选择参数w。然而,不幸的是,对于现代的过度参数化模型,如深度神经网络,典型的优化方法很容易在测试时导致次优性能。特别地,对于现代模型,LS(w)通常在w中是非凸的,具有多个局部甚至全局最小值,这些最小值可以产生相似的LS(w)值,同时具有显著不同的泛化性能(即,显著不同的LD(w)值)。

受损失图的锐度和泛化之间的联系的激励,我们提出了一种不同的方法:不是寻找简单地具有低训练损失值LS(w)的参数值w,而是寻找整个邻域具有一致低训练损失值(等价地,具有低损失和低曲率的邻域)的参数值。下面的定理通过在邻域式训练损失方面限制泛化能力来说明这种方法的动机(完整的定理陈述和证明在附录A中):

理论1:对于任何ρ > 0,从分布D生成的训练集S上具有高概率,请添加图片描述
其中,h:R+——>R+是一个严格递增的函数(在LD(w)上的某些技术条件下).为了明确我们的锐度项,我们可以将上面不等式的右边改写为:
请添加图片描述
方括号中的项通过测量通过从w移动到附近的参数值训练损失可以增加多快来捕捉LS在w处的锐度;然后,该锐度项与训练损失值本身和w幅度上的正则项相加。考虑到特定函数h受证明细节的严重影响,我们用λ||w||22代替超参数λ的第二项,从而产生标准的L2正则项。因此,受边界项的启发,我们建议通过解决以下清晰度感知最小化(SAM)问题来选择参数值:
…(数学证明部分)

这种对∇wLSAM S (w)的近似可以通过自动微分直接计算出来,就像在JAX、TensorFlow和PyTorch等常用库中实现的那样。虽然这个计算隐含地依赖于LS(w)的Hessian,因为ˇε(w)本身是∇wLS(w的函数),Hessian仅通过Hessian矢量积进入,这可以在不实现Hessian矩阵的情况下容易地计算。然而,为了进一步加速计算,我们去掉了二阶项。获得我们的最终梯度近似值:

请添加图片描述
如第3节中的结果所示,这种近似(没有二阶项)产生了一种有效的算法。在附录C.4中,我们还研究了包含二阶项的影响;在最初的实验中,包含这些术语会令人惊讶地降低性能,进一步研究这些术语的影响应该是未来工作的重点

我们通过对SAM目标LSAMS)应用诸如随机梯度下降(SGD)之类的标准数值优化器,使用等式3来计算必要的目标函数梯度,从而获得最终的SAM算法。算法1给出了完整SAM算法的伪代码,使用SGD作为基本优化器,图2示意性地说明了单个SAM参数更新。

请添加图片描述

请添加图片描述
图2.SAM参数更新示意图。

3.经验评估

为了评估SAM的功效,我们将其应用于一系列不同的任务,包括从头开始的图像分类(包括在CIFAR-10、CIFAR-100和ImageNet上),微调预训练的模型,以及使用有噪声的标签进行学习。在所有情况下,我们通过简单地用SAM替换用于训练现有模型的优化过程,并计算对模型泛化的结果影响,来衡量使用SAM的好处。如下所示,在绝大多数情况下,SAM极大地提高了泛化性能。
3.1从零开始的图像分类
我们首先在CIFAR-10和CIFAR-100(没有预训练)上评估SAM对当今最先进模型的泛化的影响:具有ShakeShake正则化的wideres nets(Zagoruyko & Komodakis,2016;Gastaldi,2017)和具有震动降正则化的金字塔网(Han等人,2016;Y amada等人,2018)。请注意,这些模型中的一些已经在之前的工作中进行了大量调整,并包括精心选择的正则化方案,以防止过度拟合;因此,显著提高它们的泛化能力是相当重要的。我们已经确保我们的实现在没有SAM的情况下的泛化性能匹配或超过先前工作中报告的性能(Cubuk等人,2018;Lim等人,2019年).

所有结果都使用基本的数据扩充(水平翻转、填充四个像素和随机裁剪)。我们还评估了更先进的数据增强方法,如剪切正则化(Devries & Taylor,2017年)和自动增强(Cubuk等人,2018年),这些方法被先前的工作用来实现最先进的结果。

SAM有一个超参数ρ(邻域大小),我们使用10%的训练集作为验证集,在{0.01,0.02,0.05,0.1,0.2,0.5}上通过网格搜索进行调整3。请参见附录C.1,了解所有超参数值和其他培训详情。因为每次SAM权重更新需要两次反向传播操作(一次计算?(w)和另一个用于计算最终梯度),我们允许每个非SAM训练运行执行两倍于每个SAM训练运行的历元,并且我们报告每个非SAM训练运行在标准历元计数或双倍历元计数4上实现的最佳分数。我们对报告结果的每个实验条件运行五个独立的副本(每个副本都有独立的权重初始化和数据重排),报告测试集的平均误差(或准确度)和相关的95%置信区间。我们的实施利用了JAX (Bradbury等人,2018年),我们在一台拥有8个NVidia V100 GPUs5的主机上训练所有模型。为了在跨多个加速器并行化时计算SAM更新,我们在加速器之间平均划分每个数据批次,独立计算每个加速器上的SAM梯度,并对得到的子批次SAM梯度进行平均,以获得最终的SAM更新。

如表1所示,SAM提高了针对CIFAR-10和CIFAR-100评估的所有设置的通用性。例如,SAM使简单的WideResNet能够达到1.6%的测试误差,而没有SAM时误差为2.2%。这种增益以前只能通过使用更复杂的模型架构(例如,金字塔网)和正则化方案(例如,摇动-摇动、摇动下降)来获得;SAM提供了一种易于实现、独立于模型的替代方案。此外,即使在已经使用复杂正则化的复杂架构上应用SAM,SAM也能提供改进:例如,将SAM应用于具有ShakeDrop正则化的金字塔网,在CIFAR-100上产生10.3%的误差,据我们所知,这是在不使用额外数据的情况下在该数据集上的新的最先进水平。

除了CIFAR-{10,100},我们还在SVHN (Netzer等人,2011年)和Fashion数据集(肖等人,2017年)上评估了SAM。再一次,SAM使一个简单的WideResNet在这些数据集上达到或超过最先进的精度:SVHN的误差为0.99%,Fashion-MNIST的误差为3.59%。详情见附录B.1。

为了在更大范围内评估SAM的性能,我们将其应用于在ImageNet (Deng et al .,2009)上训练的不同深度(50、101、152)的resnet(He et al .,2015)。在这种情况下,根据之前的工作(何等人,2015;Szegedy等人,2015),我们调整图像大小并将其裁剪为224像素分辨率,将其归一化,并使用批量大小4096,初始学习率1.0,余弦学习率调度,SGD优化器,动量为0.9,标签平滑为0.1,权重衰减为0.0001。当应用SAM时,我们使用ρ = 0.05(通过在经过100个时期训练的ResNet-50上进行网格搜索来确定)。我们使用Google Cloud TPUv3在ImageNet上训练所有模型多达400个时期,并报告每个实验条件下的前1名和前5名测试错误率(5次独立运行的平均值和95%置信区间)。

表1.SAM在CIFAR-{10,100 }(WRN = WideResNet;AA =自动增强;SGD是用于训练这些模型的标准非SAM过程)。
请添加图片描述
如表2所示,SAM再次持续提高性能,例如将ResNet-152的ImageNet top-1错误率从20.3%提高到18.4%。此外,请注意,SAM能够增加训练时期的数量,同时在不过度拟合的情况下继续提高精度。相比之下,标准训练程序(没有SAM)通常明显过度拟合,因为训练从200个时期延伸到400个时期。
请添加图片描述
表2:使用和不使用SAM时,在ImageNet上训练的ResNets的测试错误率。

3.2微调
通过在大型相关数据集上预训练模型,然后在感兴趣的较小目标数据集上进行微调的迁移学习已经成为一种强大且广泛使用的技术,用于为各种不同的任务产生高质量的模型。我们在这里显示,SAM再次在这种设置中提供了相当大的好处,即使是在微调非常大的、最先进的、已经高性能的模型时。

特别是,我们应用SAM来微调EfficentNet-b7(在ImageNet上预训练)和efficent net-L2(在ImageNet和未标记的JFT上预训练;输入分辨率475)(谭&乐,2019;Kornblith等人,2018;黄等,2018)。我们将这些模型初始化为公开可用的检查点,分别用RandAugment(在ImageNet上的准确率为84.7%)和NoisyStudent(在ImageNet上的准确率为88.2%)进行训练。我们通过从上述检查点开始训练每个模型,在几个目标数据集的每一个上微调这些模型;有关所用超参数的详细信息,请参见附录。我们报告了每个数据集5次独立运行的前1测试误差的平均值和95%置信区间。

如表3所示,相对于没有SAM的微调,SAM均匀地提高了性能。此外,在许多情况下,SAM产生了新颖的一流性能,包括CIFAR-10上的0.30%误差、CIFAR-100上的3.92%误差和ImageNet上的11.39%误差。

3.3对标签噪声的鲁棒性
SAM寻找对扰动鲁棒的模型参数的事实表明,SAM具有对训练集中的噪声(将扰动训练损失情况)提供鲁棒性的潜力.因此,我们在这里评估SAM对标记噪声提供的稳健程度。

特别地,我们测量了在CIFAR-10的经典噪声标签设置中应用SAM的效果,其中一部分训练集的标签被随机翻转;测试集保持不变(即干净)。为了确保与以前的工作进行有效的比较,以前的工作通常利用专门针对噪声标签设置的架构,我们按照江等人(2019)的方法,为200个时期训练了一个类似大小的简单模型(ResNet-32)。我们评估了模型训练的五种变体:标准SGD、具有Mixup的SGD(张等人,2017)、SAM以及具有Mixup和SAM的SGD的“引导”变体(其中模型首先被照常训练,然后在由初始训练的模型预测的标签上从头开始重新训练)。应用SAM时,对于除80%以外的所有噪声水平,我们使用ρ = 0.1,对于80%我们使用ρ = 0.05以获得更稳定的收敛。对于混合基线,我们尝试了α ∈ {1,8,16,32}的所有值,并保守地报告了每个噪声水平的最佳得分。

如表4所示,SAM对标签噪声提供了高度的鲁棒性,与专门针对带有噪声标签的学习的最新程序所提供的鲁棒性相当。事实上,简单地用SAM训练模型优于所有专门针对标签噪声鲁棒性的现有方法,除了MentorMix (Jiang等人,2019)。然而,简单地引导SAM可以产生与MentorMix相当的性能(后者要复杂得多)。
请添加图片描述
表4:在带有噪声标签的CIFAR-10上训练的模型在干净测试集上的测试准确度。下面的块是我们的实现,上面的块给出了来自文献的分数,根据蒋等人(2019)。

请添加图片描述图3:(左)在使用标准SGD(左栏)或SAM(右栏)训练模型的过程中,Hessian的频谱演变。(中)不同m值的测试误差与ρ的函数关系。(右)不同m值的泛化间隙的m-锐度预测能力(较高表示锐度测量与实际泛化间隙更相关)。

4.透过SAM的镜头看尖锐与概括

4.1 m-锐度
尽管我们的SAM推导定义了整个训练集的SAM目标,但是在实际使用SAM时,我们计算每批的SAM更新(如算法1中所述),或者甚至通过平均每个加速器独立计算的SAM更新(其中每个加速器接收一批的大小为m的子集,如第3节中所述)。后一种设置相当于修改SAM目标(等式1)以对一组独立的ε最大化,每个最大化是在m个数据点的不相交子集上的每个数据点损失的总和上执行的,而不是执行ε对训练集的全局总和进行最大化(这相当于将m设置为总训练集大小)。我们将损失景观的相关锐度度量称为m锐度。

为了更好地了解m对SAM的影响,我们使用SAM在CIFAR-10上训练了一个小型ResNet,其m值范围如图3(中)所示,m值越小,模型的泛化能力越强。这种关系恰好符合跨多个加速器并行化的需求,以便为当今的许多模型扩展培训。

有趣的是,如图3(右)所示,随着m的降低,上述m-锐度度量还表现出与模型实际泛化差距的更好相关性。特别地,这意味着m < n的m锐度比上面第2节中的定理1所建议的全训练集度量产生了更好的泛化预测,暗示了理解泛化的未来工作的有趣的新途径。

4.2HESSIAN光谱
受损失景观的几何形状和一般化之间的联系的激励,我们构造SAM来寻找具有低损失值和低曲率(即,低锐度)的训练损失景观的最小值。为了进一步证实SAM确实找到了具有低曲率的最小值,我们在训练期间的不同时期,计算了在CIFAR-10上训练了300步的具有和不具有SAM的WideResNet40-10的Hessian谱(没有批次范数,其倾向于模糊Hessian的解释)。由于参数空间的维度,我们使用Ghorbani等人(2019)的Lanczos算法来近似Hessian谱。

图3(左)报告了产生的Hessian光谱。正如所预期的,用SAM训练的模型收敛到具有较低曲率的最小值,如在特征值的总体分布中所看到的,收敛时的最大特征值(λmax)(无SAM时约为24,有SAM时约为1.0),以及大部分频谱(比率λmax/λ5,通常用作锐度的代表(Jastrzebski等人,2020);无SAM时高达11.4,有SAM时高达2.6)。

5.相关工作

寻找“平坦”极小值的想法可以追溯到Hochreiter & Schmidhuber (1995),其与泛化的联系已经有了重要的研究(Shirish Keskar等人,2016;Dziugaite & Roy,2017;Neyshabur等人,2017;Dinh等人,2017年)。在最近的一项大规模实证研究中,姜等人(2019)研究了40种复杂性度量,并表明基于锐度的度量与泛化的相关性最高,这促使惩罚锐度。Hochreiter & Schmidhuber (1997)可能是第一篇惩罚清晰度的论文,规范了与最小描述长度(MDL)相关的概念。其他也不利于尖锐最小值的想法包括对分散损失景观进行操作(Mobahi,2016年)和正则化局部熵(Chaudhari等人,2016年)

另一个方向是不明确惩罚锐度,而是在训练中平均权重;伊兹迈洛夫等人(2018)表明,这样做可以产生更平坦的最小值,也可以更好地概括。然而,先前提出的锐度度量难以计算和区分。相比之下,SAM是高度可扩展的,因为它每次迭代只需要两次梯度计算。Sun等人(2020)的并行工作集中于对随机和敌对腐败的弹性,以暴露模型的脆弱性;这项工作可能是最接近我们的。我们的工作有一个不同的基础:我们在一般化的原则起点的激励下开发SAM,通过严格的大规模经验评估清楚地证明SAM的功效,并展示该程序的重要实践和理论方面(例如,m-sharpness)。魏和马(2020)提出的全层边际概念与这项工作密切相关;一个是对网络激活的对抗性扰动,另一个是对其权重的对抗性扰动,这两个量之间存在一些耦合。

6.讨论和未来工作

在这项工作中,我们介绍了SAM,一种新的算法,通过同时最小化损失值和损失锐度来提高泛化能力;我们已经通过严格的大规模实证评估证明了SAM的功效。我们已经为未来的工作提出了一些有趣的途径。在理论方面,m-sharpness产生的每数据点锐度的概念(与过去通常研究的在整个训练集上计算的全局锐度相反)提出了一种有趣的新透镜,通过该透镜可以研究泛化。在方法学上,我们的结果表明,SAM有可能在目前依赖Mixup的健壮或半监督方法中代替Mixup(例如,给予MentorSAM)。我们把对这些可能性的更深入研究留给未来的工作。

  • 3
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值