导读
谈到 Weight Decay,势必大家都不陌生,因为这是每位 AIer 每天都会用的计数,但却又是几乎没人会关注的算法。
即使身处在以 ChatGPT 为代表所迸发的 LLM 时代,它还是那么的有用,却又那么的朴实无华。
即便是在机器学习理论圈,研究过 Weight Decay 机制的人也是少数派。
今天,为大家介绍的是基于由东京大学等机构发表于 NeuraIPS 2023 的一篇文献 《On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective》 所延伸开来的一篇关于经典算法——Weight Decay
的讨论。
请注意,这不是一篇翻译文,而更多地是为了向大家普及知识,详细说明完成这项工作的脉络。因此,请大家耐心阅读,相信你们会从中获得一些启发和收获。
背景
在正式开始之前,我想很有必要回顾下权重衰减机制的来世今生。其实,现在大家常说的 Weight Decay 至少有三种不同的形式。只是由于 Weight Decay 这个名词的滥用,大家都在不同的场合被称作 Weight Decay,仅此而已。
No.1
第一种 Weight Decay
,也就是当前 PyTorch/TensorFlow/Paddle 等深度学习框架里 optimizer 的默认 Weight Decay。简单点说,其实是L2 Regularization
😃。在 SGD 里我们可以将它写成下面这种形式:
L2 正则化很好理解,实际上是在原有的 loss 旁边额外加多一个惩罚项。所以它的特点是,正则作用会先进入每一步迭代的梯度里,最终影响神经网络的参数。
No.2
第二种 Weight Decay
,也就是当前大模型训练常用优化器 AdamW 里的 Decoupled Weight Decay
。在 SGD 里可以写成下面这种形式:
这才是真正意义上的 decay weights,因为这里每一步直接将权重以 1 − n t λ 1 - n_{t}\lambda 1−ntλ 因子衰减。
好吧,很容易看出来,对于 SGD 来说,上面两种形式是等价的。这也是为什么L2 Regularization
会滥用Weight Decay
这个名字的原因:(
但更复杂的优化器,比如 SGD Momentum 或 Adam,会使用动量(Momentum)和自适应学习率(Adaptive Learning Rate)等方法。这些实际上大家最常用的优化器里,没有一个是能让上面两种形式正好等价的。
Adam (with L2 Regularization)和 AdamW (with Decoupled Weight Decay)这两个大模型训练最常用的两个优化器实际上只有 weight decay 的形式不同。但是他们的性能差异大到以至于分成了两个名字不同的优化器。
那究竟为什么会这样,详听下文讲解。
No.3
第二种 Weight Decay
,也就是在训练神经网络时专门被提出来的原始 Weight Decay,来自一篇 1989 年的 NeurIPS 论文。在 SGD 里可以写成下面这种形式:
这是在做学习率解耦的 decay weights,每一步直接将权重以 1 − λ ′ 1-\lambda^{\prime} 1−λ′ 因子衰减。
哪种更好呢?
看到这里,可能大家会有个疑问,上面提到了三种不同形式的权重衰减方式,那究竟该用哪种呢?
其实,第三种 Weight Decay 是一种已经被淘汰的形式。虽然深度学习框架的开发者们可能没有专门研究过这个问题,但实际上其学习率解耦模式确实不如第二种形式。这是因为对于第三种 Weight Decay 来说,在学习衰减得比较小的训练末期,Weight Decay 的强度总是太大了。我们很容易在实验里观察到第三种 Weight Decay 的缺陷。
而第一种的 L2 Regularization 仍然在当前的深度学习框架流行;而第二种的 Decoupled Weight Decay 几乎只在 Transformer 等架构盛行的 AdamW 等少量优化器里有官方实现。这个现状并不好…
对于一些优化器,比如 SGD(Momentum)等,L2 Regularization 和 Decoupled Weight Decay 的性能并不是很大。但每当 L2 Regularization 和 Decoupled Weight Decay 差别很大时,几乎总是 Decoupled Weight Decay 显著地比 L2 Regularization 好。这几乎发生在所有自适应优化器上。
所以<更好的原则是把 Decoupled Weight Decay 作为Weight Decay 的默认实现,把 L2 Regularization 作为备选。
说到这,挑刺的小伙伴可能又会说,既然这样,那 PyTorch/TensorFlow 等框架为啥直到今天还是将 L2 正则化作为默认配置?此处盲猜可能是因为深度学习框架滥用名词 Weight Decay 指代 L2 Regularization 造成的一个历史遗留问题…
有啥隐藏缺陷?
那么 Weight Decay 的故事就这么结束了吗?
不,上述论述只是个引子,目的就是为了点出当前权重衰减的缺陷。
是的,哪怕是发展了这么多年的已经非常成熟的 Weight Decay,其实还有挺严重的隐藏缺陷未被发现,这便是今天介绍的这个工作的主要贡献,下面让我们一探究竟!
不得不说,如果一个这么常用的算法还有严重的缺陷,那这个缺陷确实藏得挺深。
这件事说起来也有点离谱。我第一次得到前文两个部分的 motivation 已经是3年前的事了。但深度学习理论的论文并不多了,所以一直也没有其他人把 Weight Decay 的这些特点说不清。对于第三种 Weight Decay 为什么比第二种 Weight Decay 差那么多,我一直没有得到一个比较完整的理论解释…
直到一个月黑风高的晚上,の,去年,终于有空搞点基础研究,心血来潮做了一个实验,把 Gradient Norm
随着训练过程的变化可视化出来。
通过观察 ResNet18 在 CIFAR-10 训练时的 Gradient Norm 曲线可以发现,Gradient Norm 会随着 Weight Decay 的增加而显著增加。
好的。如果你有优化理论和泛化理论的基础,那么应该一下就能看出来 Weight Decay 它确实有点大毛病了!不用担心,我知道屏幕前的你看不懂,没关系,我们继续… _
这个实验其实呈现出一个违背传统认知的现象——Gradient Norm 竟然会随着 Weight Decay 增加而显著增加!为什么这样说,因为这个现象至少从三个角度来说,都是和传统的理解是不相符的:
-
优化角度。Large Gradient Norms 说明训练算法的收敛性变成很差,因为从某种意义上来说,gradient norm 的大小就是收敛性最常见的指标。
-
泛化角度。Large Gradient Norms 说明训练得到的权重泛化性很差,因为 gradient norm 的大小也是泛化界常见的度量之一。
-
正则化角度。Large Gradient Norms 说明训练得到的权重复杂性变高了。而传统上,大家认为 Weight Decay 是通过正则化作用来降低神经网络复杂性,从而提高模型的泛化能力。而今天这实验的结果竟然正好相反,你说怪不怪?
因此,从这三个角度来说,Weight Decay 都大有问题。这也再次说明,在深度学习里,很多传统理解都是很粗浅的、适用范围狭窄的。当然,大家都早已习惯做屌爆侠。
关于这三点,原文里提供了详细的分析和参考文献,感兴趣的小伙伴可以去研读。
找到病因就好办了,因为接下来需要做的事情便是对症下药。于是,花了一天时间,最终把 Gradient Norm 的 Upper Bound 和 Lower Bound 都和 Weight Decay 强度正相关的理论证明了一下,算是把 Weight Decay 过去被忽略的关于 gradient norm 理论补全了。我们理论也发现这些缺陷在自适应优化器里尤为突出。这部分都是一个理论性论文里的程序性技术工作。
如何对症下药?
本文工作设计了一个算法——Scheduled Weight Decay
来弥补 Weight Decay 的上述缺陷,也就是使用 Weight Decay 的时候,同时可以抑制 Gradient Norm。
这个方法思想上也很简单,就是当 Gradient Norm 太大的时候就让 Weight Decay 强度小一点,Gradient Norm 太小的时候就让 Weight Decay 强度大一点发挥作用。
如下图8所示,所提优化算法 AdamS (Adam with Scheduled Weight Decay) 的确很好地完成了抑制 Gradient Norm 的作用。
可以看出,Scheduled Weight Decay 相比 Constant Weight Decay 显著降低训练末期的 Gradient Norm!
同时,在上图9我们还能看到,随着 Weight Decay 增加时找到了更 sharp 的 minima(top Hessian eigenvalues 增大),这通常意味着泛化性能不好。而 Scheduled Weight Decay 找到了比 Constant Weight Decay 更 flat 的 minima(top Hessian eigenvalues明显变小)。
一点浅显的分析:
在深度学习中,Hessian 矩阵是损失函数相对于模型参数的二阶导数矩阵。对 Hessian 矩阵进行特征值分解,得到的特征值中,top Hessian eigenvalues 表示最大的特征值。这些特征值提供了关于损失曲面形状的信息。
此外,在训练神经网络时,找到损失函数的极小值其实是一个优化问题。而 Sharp minima 便是损失函数曲面上陡峭的最小值,而 flat minima 则是相对较平缓的最小值。特征值的大小和损失曲面的形状有关,top Hessian eigenvalues 表示 Hessian 矩阵中最大的特征值,因此反映了曲面在该方向上的曲率。
上面我们提到作者通过观察 top Hessian eigenvalues 随着 Weight Decay 变化的趋势,可以得知损失曲面的形状如何随之变化。当 top Hessian eigenvalues 增大时,表示曲面更陡峭,这可能导致模型更容易陷入 sharp minima,而泛化性能可能较差。相反,当 top Hessian eigenvalues 变小时,表示曲面相对平坦,模型更可能落入 flat minima,这可能对泛化性能更有利。因此,只需调整权重衰减的对应策略,使得 top Hessian eigenvalues 减小,便意味着找到的 minima 相对更 flat。这种调整是有助于改善模型的泛化性能。
一点言外之意
一些审稿人会把这个算法当作最主要的贡献,这说明这部分审稿人始终是在从 Engineering 的角度来看待这个工作的。直白点说就是,他们认为 AI 只是一个黑盒机器,只要能 work 得更好,那便是研究的全部。
应用研究的主(quan)要(bu)哲学——那确实就是 work 得更好就行。带实习生做 Neural Fields 方向的应用研究最近半年也分别被 ICCV2023 和 ICLR2024,一投既中。只要性能好了,基本上就没有遇到过负分的审稿人。老实说,这种对性能的朴素追求,确实有一点朴实无华…
说到这,可能大部分人都对“刷榜”深恶痛绝,但却又无可奈何,只能被逼着刷,打不过就换个角度刷,实在不行就换个方向再刷,(⊙﹏⊙)~~~
但包括作者在内的部分理论研究者一般更愿意把 AI 当作一个客观的研究对象。这种研究视角可以叫做 Science of AI
。包括 Sébastien Bubeck 在内的部分理论研究者一般是把 AI 当作一个客观的研究对象,也就是我们主要是在做科学发现。偶尔,这些科学发现能促进更好的机器的诞生。微软AI理论研究的带头人 Sébastien Bubeck 在他的主页在大模型的时代背景下更近一步把自己的研究定位为Physics of AGI。
Sébastien Bubeck 主页:http://sbubeck.com/
研究神经网络和研究玻色爱因斯坦凝聚差别真的很大吗?并不,他们肯定要服从某些真理。渐进低逼近真理就是新的科学。
我们另外一篇 NeurIPS2023 其实也是科学发现型的论文,和一位物理系任教的朋友第一天讨论 idea 做实验,然后第二天就在神经网络里发现了我们预言的实验结果。这种并不会出现在论文里的故事,可以留待下回写博客时介绍。
我个人觉得这篇文章最大的贡献其实是定位 Weight Decay 的隐藏缺陷上,算法只是自然而然的副产品罢了。这个工作也是在理论发现完成后立刻被接收了,而不是在算法和实验的时候(很早就在Github上开源这个Weight Decay算法了)。而 Scheduled Weight Decay 也必然不是解决 Weight Decay 隐藏缺陷的最终答案。因为它解决的是 Gradient Norm 的问题,而不是全部的问题。Weight Decay 仍然还有我们不理解的问题。
Github: https://github.com/zeke-xie/stable-weight-decay-regularization
完毕,感谢阅读。