一.局部最小值与鞍点
在做优化时,我们会发现,随着参数的更新,训练的损失基本上不会再下降,但是我们对此时的损失值依旧不太满意。并且把深度学习得到的loss值与线性模型和浅层网络对比,我们发现深度学习得到的loss值并没有比其他两个更优越,这说明我们的优化可能出了一定的问题
如下图:
在这个图里面,我们看到梯度接近于0,我们首先认为这就是局部最小值(local minimum),但是事实上,还可能是到了鞍点(saddle point)附近,鞍点其实就是梯度是零且区别于局部极小值和局部极大值(localmaximum)的点。我们把梯度为零的点称为临界点下面这幅图可以清楚的表示局部最小值和鞍点的区别:
我们看到,如果梯度下降到局部最小值附近,我们想要继续降低梯度,这可能比较困难;但是如果是下降到鞍点附近,那么继续使梯度下降,就显得很简单了,逃离鞍点,就可能让损失继续降低————所以,探讨临界点是鞍点还是局部最小值就具有了它的意义。
1.2逃离鞍点的方法
如图,在一维平面的曲线里的红点处是局部最小值,但是在二维平面里它却成了鞍点,这说明:
💡 低维度空间中的局部极小值点,在更高维的空间中,实际上可能是鞍点。
如果在一维空间无路可走,我们可以尝试在二维空间看看该点在二维空间是不是鞍点;
如果在二维空间无路可走,我们可以尝试在三维空间看看该点在三维空间是不是鞍点…
按照这个想法类推,在维度极高(成千上万)的时候,局部最小值大多数都可以看作鞍点,只有很少很少的临界点是局部最小值。
在训练一个网络的时候,参数数量随便就能达到成千上千——参数的数量代表了误差表面的维度——这意味着在真正训练的时候鞍点占大多数。
事实上,经过实验实践,我们遇到的大多数都是鞍点,局部最小值反而少见。
2.2批量和动量
在计算梯度的时候,我们一般不是对每一个数据的损失L挨个挨个来计算梯度,而是分成一个又一个的批量(Batch)进行计算,把所有的批量(Batch)计算完一次我们称为经历了一个回合(epoch)
事实上,我们在把数据分为批量时,会进行随机打乱(shuffle)。随机打乱有很多种做法,其中之一就是在每个回合(epoch)开始前重新划分批量——意味着每个回合的批量数据都不一样
2.2.1批量梯度下降法和随机梯度下降法
- 批量梯度下降法(BGD):将整个数据集作为一个批量,这种使用全批量(full batch)的数据来更新参数的方法就是批量梯度下降法。
- 随机梯度下降法(SGD):批量大小等于一,相当于把每个数据遍历。也叫做增量梯度下降法。
对于随机梯度下降:由于每一笔数据都要进行损失计算和参数更新,所以如果有1000笔数据,就会进行1000次参数更新,用一笔数据算出来的损失相对带有更多的噪声,所以它对应得曲线比较曲曲折折。
对于批量梯度下降:批量梯度下降并没有“划分批量”:要把所有的数据都看过一遍,才能够更新一次参数,因此其每次迭代的计算量大。但相比随机梯度下降,批量梯度下降每次更新更稳定、更准确。
💡 随机梯度下降的梯度上引入了随机噪声,因此在非凸优化问题中,其相比批量梯度下降更容易逃离局部最小值。
从这个图里我们可以看到两种方法的区别。
我们讨论两种方法的运行时间的区别:
- 对于批量梯度下降:
实际上,考虑并行运算,批量梯度下降花费的时间不一定更长:
如图,GPU是有并行运算的能力的,但是存在极限。
当批量大小很小时(对于批量梯度下降也就意味着数据集很小),批量梯度下降所耗费的时间是很短的;而当批量大小过大时,所花费的时间会大幅上升。
2. 对于随机梯度下降
我们看到,在批量大小比较小的时候,随机梯度下降过完一整个epoch反而花费了更多的时间;在批量大小较大时,使用随机梯度下降过完一整个epoch花费的时间就相对较少了。
总结:
实际上,在考虑GPU并行计算的能力的时候,小的批量大小使用批量梯度下降比较有效率;大的批量大小使用随机梯度下降比较有效率
大的批量更新比较稳定,小的批量的梯度的方向是比较有噪声的(noisy),但是这并不意味着小的批量所带来的噪声会让训练效果变差,反而这些噪声对训练有帮助!!!
如图,在同一个模型下,Batch Size很大时,准确度很低!!!
附:小的batch对testing data也有很好的效果!!!
2.2.2动量法
假设误差表面就是真正的斜坡,参数是一个球,把球从斜坡上滚下来,如果使用梯度下降,球走到局部最小值或鞍点就停住了。
但是在物理的世界里,一个球如果从高处滚下来,就算滚到鞍点或鞍点,因为惯性的关系它还是会继续往前走。
如果球的动量足够大,其甚至翻过小坡继续往前走。 因此在物理的世界里面,一个球从高处滚下来的时候,它并不一定会被鞍点或局部最小值卡住,如果将其应用到梯度下降中,这就是动量。
一般的梯度下降:
在使用了动量法以后:
(该笔记内容来自于李宏毅老师的《机器学习》及其配套课程,会随夏令营进行而更新)