目录
说明
模型每次反向传导都会给各个可学习参数p计算出一个偏导数,用于更新对应的参数p。通常偏导数不会直接作用到对应的可学习参数p上,而是通过优化器做一下处理,得到一个新的值,处理过程用函数F表示(不同的优化器对应的F的内容不同),即,然后和学习率lr一起用于更新可学习参数p,即。
Adam是在RMSProp和AdaGrad的基础上改进的。先掌握RMSProp的原理,就很容易明白Adam了。本文是在RMSProp这篇博客的基础上写的。
Adam原理
在RMSProp的基础上,做两个改进:梯度滑动平均和偏差纠正。
梯度滑动平均
在RMSProp中,梯度的平方是通过平滑常数平滑得到的,即(根据论文,梯度平方的滑动均值用v表示;根据pytorch源码,Adam中平滑常数用的是β,RMSProp中用的是α),但是并没有对梯度本身做平滑处理。
在Adam中,对梯度也做了平滑,平滑后的滑动均值用m表示,即,在Adam中有两个β。
偏差纠正
上述m的滑动均值的计算,当时,,由于的初始是0,且β接近1,因此t较小时,m的值是偏向于0的,v也是一样。这里通过除以来进行偏差纠正,即。
Adam计算过程
为方便理解,以下伪代码和论文略有差异,其中蓝色部分是比RMSProp多出来的。
- 初始:学习率 lr
- 初始:平滑常数(或者叫做衰减速率) ,分别用于平滑m和v
- 初始:可学习参数
- 初始:
- while 没有停止训练 do
- 训练次数更新:
- 计算梯度:(所有的可学习参数都有自己的梯度,因此 表示的是全部梯度的集合)
- 累计梯度:(每个导数对应一个m,因此m也是个集合)
- 累计梯度的平方:(每个导数对应一个v,因此v也是个集合)
- 偏差纠正m:
- 偏差纠正v:
- 更新参数:
- end while
pytorch Adam参数
torch.optim.Adam(params,
lr=0.001,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0,
amsgrad=False)
params
模型里需要被更新的可学习参数
lr
学习率
betas
平滑常数和
eps
,加在分母上防止除0
weight_decay
weight_decay的作用是用当前可学习参数p的值修改偏导数,即:,这里待更新的可学习参数p的偏导数就是
weight_decay的作用是L2正则化,和Adam并无直接关系。
amsgrad
如果amsgrad为True,则在上述伪代码中的基础上,保留历史最大的,记为,每次计算都是用最大的,否则是用当前。
amsgrad和Adam并无直接关系。