滑动平均算法

为什么要使用滑动平均模型?
因为在神经网络中, 更新的参数时候不能太大也不能太小,更新的参数跟你之前的参数有联系,不能发生突变。一旦训练的时候遇到个“疯狂”的参数,有了滑动平均模型,疯狂的参数就会被抑制下来,回到正常的队伍里。这种对于突变参数的抑制作用,用专业术语讲叫鲁棒性,鲁棒性就是对突变的抵抗能力,鲁棒性越好,这个模型对恶性参数的提抗能力就越强。

参数decay:滑动平均模型通过控制衰减率(decay)来控制参数更新前后之间的差距,从而达到减缓参数的变化幅度的目的(如,参数更新前是5,更新后的值是4,通过滑动平均模型之后,参数的值会在4到5之间),如果参数更新前后的值保持不变,通过滑动平均模型之后,参数的值仍然保持不变。
**num_updates:**默认为None,如果设置了num_updates,那么这个参数就代表了模型更新的次数,每次使用的衰减率将会按照如下公式更新:
decay = min{init_decay , (1 + num_update) / (10 + num_update)}
随着 num_update 更新次数的增加,(1 + num_update) / (10 + num_update 这一项的计算结果越接近1

此时原模型参数按照以下公式更新:
shadow_variable = decay * shadow_variable + (1 - decay) * variable
其中 shadow_variable 为变量更新前的数值,我们叫它影子变量,variable为变量更新后的数值,上面我们说到,随着num_updates的增大,decay的计算结果越来越接近1,那么1-decay就趋近于0,即,模型参数更新变得越来越慢。

这样,我们就可以通过对num_updates的控制来使得模型在训练初期参数更新幅度加大,在接近最优值处参数更新幅度减小,即在减少训练时间的基础上保证模型训练的精度。

如何使用滑动平均模型?

import tensorflow as tf

#定义一个变量用于计算滑动平均,初始值为 0 ,类型为实数
v1 = tf.Variable(0, dtype = tf.float32)
#step变量用来模拟神经网络中迭代的轮数,即我们上面说的num_updates参数,用来动态控制衰减率
step = tf.Variable(0, trainable = False)

#定义一个滑动平均类,初始化衰减率(0.99)和衰减率控制变量step
#该函数返回一个ExponentialMovingAverage对象,该对象调用apply方法可以通过滑动平均模型来更新参数
ema = tf.train.ExponentialMovingAverage(0.99, step)

#定义一个更新变量滑动平均的操作。
#这里的给定数据需要是列表的形式,每次执行这个操作时列表中的变量都会被更新
maintain_averages_op = ema.apply([v1])

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
#通过ema.average(v1)获取滑动平均之后变量的取值,此处输出为[0.0 , 0.0]
#初始化之后变量v1和v1的滑动平均都为0
    print sess.run([v1, ema.average(v1)])

#更新变量v1的值为5
    sess.run(tf.assign(v1, 5))
#更新v1的滑动平均值
#此时衰减率为min(0.99,(1+step)/(10+step)=0.1) = 0.1
#所以v1的滑动平均会被更新为0.1*0 + 0.9*5 = 4.5
    sess.run(maintain_averages_op)
    print sess.run([v1, ema.average(v1)])   #输出[5.0 , 4.5]

#更新step的值为10000
    sess.run(tf.assign(step, 10000))
#更新v1的值为10
     sess.run(tf.assign(v1, 10))
#计算v1的滑动平均值
#此时衰减率为min(0.99,(1+step)/(10+step)=0.999999) = 0.99
#所以v1的滑动平均会被更新为0.99*4.5 + 0.01*10 = 4.555
    sess.run(maintain_averages_op)
    print sess.run([v1, ema.average(v1)])   #输出[10.0, 4.5549998]  
       
#再次更新滑动平均值,得到的新的滑动平均值为0.99*4.555 + 0.01*10 = 4.60945
    sess.run(maintain_averages_op)
    print sess.run([v1, ema.average(v1)])    #输出[10.0, 4.6094499]
    

吴恩达老师Deeplerning课程对于滑动平均算法有更深的讲解。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值