文章目录
梯度下降中对抗局部极小值与鞍点
一、背景
我们在做优化的时候会发现,随着参数不断更新,训练的损失不会再下降,有时候深层网络并没有相较于浅层网络做得更好——深层网络没有发挥出它完整的力量,因为按理说随着网络层数不断增加,模型的性能会不断提高。这里举ResNet论文中的一个例子1:
如上图所示,这篇论文在测试集上测试两个网络,一个网络有 20 层,一个网络有 56 层。出乎意料的是,随着迭代次数增加, 56 层网络的损失比 20 层网络的损失还高。
这种现象往往说明优化是有问题的。常见的一个猜想是我们优化到某个地方,这个地方损失关于参数的微分为零,这样梯度下降就不能再更新参数,训练就停下来,损失不再下降。导致上述问题的原因之一是,模型收敛到了局部极小值(local minimum)或鞍点(saddle point),神经网络的损失函数是一个非凸函数,找到全局最优解通常比较困难。
补充:这边给大家的建议是看到一个从来没有做过的问题,可以先跑一些比较小的、比较浅的网络,或甚至用一些非深度学习的方法,比如线性模型、支持向量机(Support Vector Machine,SVM)等2,就 SVM 来说,其不容易有优化失败的问题,因为SVM的优化问题是一个典型的凸优化问题(即Hessian矩阵处处半正定),局部最优解就是全局最优解。简单来说,这些模型会竭尽全力的,在它们的能力范围之内,找出一组最好的参数。因此可以先训练一些比较浅的,或是比较简单的模型,先了解这些简单的模型到底可以得到什么样的损失。
二、谁是凶手?局部最优点or鞍点?
从几何意义来解释下鞍点和局部极小值:
局部极小值:在函数曲面上,局部极小值点是一个“谷底”,周围的所有点都比它高。
鞍点:在函数曲面上,鞍点类似于马鞍的形状,沿着某些方向是“谷底”,但是沿着其他方向却是“山峰”。
判断一个驻点(Stationary Point) 到底是局部极小值还是鞍点需要知道损失函数的形状,虽然神经网络的损失函数较为复杂,但是我们可以通过给定的参数,使用多元泰勒展开进行估计,判断二阶项中Hessian矩阵的特征值,即可推断出驻点的种类。假设给定的参数为 θ ′ \theta' θ′ ,损失函数为 L ( θ ) L(\theta) L(θ),则 θ \theta θ 附近的 L ( θ ) L(\theta) L(θ) 可近似为:
L ( θ ) ≈ L ( θ ′ ) + ( θ − θ ′ ) T g + 1 2 ( θ − θ ′ ) T H ( θ − θ ′ ) L(\boldsymbol{\theta}) \approx L\left(\boldsymbol{\theta}^{\prime}\right)+\left(\boldsymbol{\theta}-\boldsymbol{\theta}^{\prime}\right)^{\mathrm{T}} \boldsymbol{g}+\frac{1}{2}\left(\boldsymbol{\theta}-\boldsymbol{\theta}^{\prime}\right)^{\mathrm{T}} \boldsymbol{H}\left(\boldsymbol{\theta}-\boldsymbol{\theta}^{\prime}\right) L(θ)≈L(θ′)+(θ−θ′)Tg+21(θ−θ′)TH(θ−θ′)
其中:
- L ( θ ′ ) L(\theta') L(θ′) :表示损失函数在 θ ′ \theta' θ′ 点的值,这是一个固定值
- ( θ − θ ′ ) T g (\theta - \theta')^T g (θ−θ′)Tg :描述了 L ( θ ) L(\theta) L(θ) 在 θ ′ \theta' θ′ 附近的线性变化,方向由梯度 g g g 决定。如果梯度 g = 0 g = 0 g=0,则说明 θ ′ \theta' θ′ 是一个驻点,可能是局部最优点或者是鞍点。有的时候梯度 g g g 会写成 ∇ L ( θ ′ ) \nabla L(\theta') ∇L(θ′)。
- 1 2 ( θ − θ ′ ) T H ( θ − θ ′ ) \frac{1}{2}\left(\boldsymbol{\theta}-\boldsymbol{\theta}^{\prime}\right)^\mathrm{T}\boldsymbol{H}\left(\boldsymbol{\theta}-\boldsymbol{\theta}^{\prime}\right) 21(θ−θ′)TH(θ−θ′):泰勒展开的第二项,这一项决定了 θ ′ \theta' θ′ 附近的误差表面(error surface):
- 如果 H H H 是正定的(所有特征值都大于 0),则 θ ′ \theta'