Nesterov Momentum 工程实现上的trick

Nesterov Momentum 工程实现上的trick

Nesterov Momentum是momentum这种优化方法的一个变种,其参数更新规则这样的:

vαv+grad(θ+αv)θθlrv

参数更新规则这样写有一个问题。一般情况下,(以tensorflow为例)optimizationMethod所接受的参数只有计算好的 grad(θ) θ , 那么我们怎么计算 grad(θ+αv) 的值呢?

查看一下tensorflow的代码,ApplyMomentum 的实现非常简单:

template <typename T>
struct ApplyMomentum<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    accum.device(d) = accum * momentum() + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr() + accum * momentum() * lr();
    } else {
      var.device(d) -= accum * lr();
    }
  }
};

tensorflow_source_code

但是看起来跟标准的update rule 完全不一样啊。

解释这个问题还是得借用Hinton老爷子的slides。

这里写图片描述

对于标准的update rule, 一个iteration走过的路径是 012 01 对应 αv , 12 对应 grad(θ+αv) . 但是我们只知道0点的gradient, 不知道1这点的gradient。

既然我们只需要1这点的gradient,那么就把传进来的gradient当成1点的gradient好了。这样,每个iteration的路径就变成了 123 。OK, 假设一个iteration结束了,这时的weight和gradient都是在1点的weight和gradient,而 v 则是0点的v。那么下一个iteration, 12 就是 grad(θ) 。要计算 23 , 我们需要知道2点的 v 。怎么算呢?v2αv0+grad(θ1)。然后 23 就变成了 αv2

因此,我们的update rule变成了:

vαv+grad(θ)θθlr(grad(θ)+αv)

这就和tensorflow的实现一致了。

通过上面的分析,我们知道,如果按照新的update rule计算出的当前的 θ v 的位置是正确的话(和老的update rule得到结果是一样的),那么下一步新老update rule就是等价的。那么根据数学归纳法,我们只要证明,两种update rule在第一个iteration是等价的,就可以确认两种update rule是等价的了。

在第一个iteration时v=0,因此 01 是一个0向量,那么此时的weight θ 既可以理解成0点weight也可以理解成1点的weight。因此两种update rule在初始条件下也是等价。所以,tensorflow的实现是对的。

注:本文参考了这个问题的回答

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值