常见优化器SGD-NadaMax
转载 https://zhuanlan.zhihu.com/p/81020717
- SGD
- Momentum
- Nesterov Momentum
- AdaGrad
- RMSProp
- AdaDelta
- Adam
- AdaMax
- Nadam
- NadaMax
SGD
随机梯度下降算法,最常用地mSGD,小批量随机梯度下降求解参数。
Momentum
动量梯度下降算法
代码
import numpy as np
class Momentum(object):
def __init__(self,alpha=0.0,lr=1e-3):
self.alpha = alpha
self.lr = lr
self.v = 0
def update(self,g:np.ndarray):
self.v = self.alpha*self.v-self.lr*g
return self.v
思想:梯度下降不仅与当前所求梯度有关,还与上次该参数所求梯度有关,按照一定关系合并 (\alpha)
公式:
Nesterov Momentum
思想:Nesterov先用当前的速度v_{i} 更新一遍参数,得到一个临时参数 w_{x} ,然后使用这个临时参数计算本轮训练的梯度
公式:
代码
AdaGrad
adaptive subgradient,主要特点是不断累加每次训练中梯度地平方
公式:
g_{i}^2 表示地是 矩阵地哈达玛积
思想:随着算法迭代,r会越来越大,整体地学习率会越来越小,故一开始是激励收敛,后面变成惩罚收敛。
代码:
class AdaGrad(object):
def __init__(self, eps=1e-8, lr=1e-3):
self.r = eps # r_0 = epsilon
self.lr = lr
def update(self, g: np.ndarray):
r = r + np.square(g)
return -self.lr * g / np.sqrt(r)
RMSProp
RMSProp是AdaGrad的改进算法,其公式和AdaGrad的区别只有 r_{i}的计算不同
公式:
思想 ;RMSProp只会累积近期的梯度信息,对于“遥远的历史”会以指数衰减的形式放弃,AdaGrad算法虽然在凸函数(Convex Functions)上表现较好,但是当目标函数非凸时,算法梯度下降的轨迹所经历的结构会复杂的多,早期梯度对当前训练没有太多意义。
代码
class RMSProp(object):
def __init__(self,lr=1e-3,beta=0.999,eps=1e-8):
self.r = eps
self.lr = lr
self.beta = beta
def update(self,g:np.bdarray):
r =r*self.beta+(1-self.beta)*np.square(g)
return -self.lr*g/np.sqrt(r)
AdaDelta
AdaDelta是与RMSProp相同时间对立发展出来的一个算法
公式:
思想:算法不需要设置学习率,同样以 r_{i}来累积梯度的信息之外,该算法还多了一个 s_{i} 以指数衰减的形式来累积 \Delta w 的信息
初始化r,s最小值
代码
class AdaDelta(object):
def __init__ (self,beta=0.999,eps = 1e-8):
self.r =eps
self.s = eps
self.beta = beta
def update(self,g:np.ndarray):
g_square = (1-self.beta)*np.square(g)
r =r * self.beta+g_square
frac = s/r
res = -np.sqrt(frac) *g
s = s*self.beta+frac*g_squeare
return res
Adam
Adam的名称来自Adaptive Momentum
公式
思想:可以看作是Momentum与RMSProp的一个结合体,该算法通过计算梯度的一阶矩估计和二阶矩估计而为不同的参数设计独立的自适应性学习率
代码:
class Adam(object):
def __init__(self, lr=1e-3, alpha=0.9, beta=0.999, eps=1e-8):
self.s = 0
self.r = eps
self.lr = lr
self.alpha = alpha
self.beta = beta
self.alpha_i = 1
self.beta_i = 1
def update(self, g: np.ndarray):
self.s = self.s * self.alpha + (1-self.alpha) * g
self.r = self.r * self.beta + (1-self.beta) * np.square(g)
self.alpha_i *= self.alpha
self.beta_i *= self.beta_i
lr = -self.lr * (1-self.beta_i)**0.5 / (1-self.alpha_i)
return lr * self.s / np.sqrt(self.r)
AdaMax
公式
思想 :max比较的是梯度各个维度上的当前值和历史最大值
w的各维度的增量是根据该维度上梯度的 L_2 范数的累积量进行缩放的。如果用 L_p范数替代就得到了Adam的不同变种,不过其中 L_p范数对应的变种算法简单且稳定
代码
class AdaMax(object):
def __init__(self, lr=1e-3, alpha=0.9, beta=0.999):
self.s = 0
self.r = 0
self.lr = lr
self.alpha = alpha
self.alpha_i = 1
self.beta = beta
def update(self, g: np.ndarray):
self.s = self.s * self.alpha + (1-self.alpha) * g
self.r = np.maximum(self.r*self.beta, np.abs(g))
self.alpha_i *= self.alpha
lr = -self.lr / (1-self.alpha_i)
return lr * self.s / self.r
Nadam
思想:Adam可以看作是Momentum与RMSProp的结合,既然Nesterov的表现较Momentum更优,那么自然也就可以把Nesterov Momentum与RMSProp组合到一起
公式:
代码
class Nadam(object):
def __init__(self, lr=1e-3, alpha=0.9, beta=0.999, eps=1e-8):
self.s = 0
self.r = eps
self.lr = lr
self.alpha = alpha
self.beta = beta
self.alpha_i = 1
self.beta_i = 1
def update(self, g: np.ndarray):
self.s = self.s * self.alpha + (1-self.alpha) * g
self.r = self.r * self.beta + (1-self.beta) * np.square(g)
self.alpha_i *= self.alpha
self.beta_i *= self.beta_i
lr = -self.lr * (1-self.beta_i)**0.5 / (1-self.alpha_i)
return lr * (self.s * self.alpha + (1-self.alpha) * g) / np.sqrt(self.r)
NadaMax
思想:Nesterov与AdaMax结合变成NadaMax
公式:
代码
class NadaMax(object):
def __init__(self, lr=1e-3, alpha=0.9, beta=0.999):
self.s = 0
self.r = 0
self.lr = lr
self.alpha = alpha
self.alpha_i = 1
self.beta = beta
def update(self, g: np.ndarray):
self.s = self.s * self.alpha + (1-self.alpha) * g
self.r = np.maximum(self.r*self.beta, np.abs(g))
self.alpha_i *= self.alpha
lr = -self.lr / (1-self.alpha_i)
return lr * (self.s * self.alpha + (1-self.alpha) * g) / self.r
参考资料
[1]: 《机器学习算法背后的理论与优化》 ISBN 978-7-302-51718-4
[2]: Adam: A Method for Stochastic Optimization
[3]: Incorporating Nesterov Momentum into Adam
[4]: An overview of gradient descent optimization algorithms