指数滑动平均(ExponentialMovingAverage)EMA


EMA被广泛的应用在深度学习的BN层中,RMSprop,adadelta,adam等梯度下降方法


tf.train.ExponentialMovingAverage
函数定义
tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指数衰减来计算变量的移动平均值。
tf.train.ExponentialMovingAverage.init(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):
decay是衰减率在创建ExponentialMovingAverage对象时,需指定衰减率(decay),用于控制模型的更新速度。decay设置为接近1的值比较合理,通常为:0.999,0.9999。这里的一个trick是,

例如,

0.95^(20)=0.3584
1/e=0.3678
两者大概是近似相等的,也许这就是指数滑动平均中指数的含义吧。
影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:


num_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时提供了参数,即不为none时,每次的衰减率是:

apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。
average()和average_name()方法可以获取影子变量及其名称。


Tensorflow栗子:

import tensorflow as tf

# 定义一个32位浮点数的变量,初始值位0.0
v1 =tf.Variable(dtype=tf.float32, initial_value=0.)

# 衰减率decay,初始值位0.99
decay = 0.99

# 定义num_updates,同样,初始值位0
num_updates = tf.Variable(0, trainable=False)

# 定义滑动平均模型的类,将衰减率decay和num_updates传入。
ema = tf.train.ExponentialMovingAverage(decay=decay, num_updates=num_updates)

# 定义更新变量列表
update_var_list = [v1]

# 使用滑动平均模型
ema_apply = ema.apply(update_var_list)

# Tensorflow会话
with tf.Session() as sess:
    # 初始化全局变量
    sess.run(tf.global_variables_initializer())

    # 输出初始值
    print(sess.run([v1, ema.average(v1)]))      
    # [0.0, 0.0](此时 num_updates = 0 ⇒ decay = .1, ),
    # shadow_variable = variable = 0.

    # 将v1赋值为5
    sess.run(tf.assign(v1, 5))

    # 调用函数,使用滑动平均模型
    sess.run(ema_apply)

    # 再次输出
    print(sess.run([v1, ema.average(v1)]))     
    # 此时,num_updates = 0 ⇒ decay =0.1,  v1 = 5; 
    # shadow_variable = 0.1 * 0 + 0.9 * 5 = 4.5 ⇒ variable

    # 将num_updates赋值为10000
    sess.run(tf.assign(num_updates, 10000))

    # 将v1赋值为10
    sess.run(tf.assign(v1, 10))

    # 调用函数,使用滑动平均模型
    sess.run(ema_apply)

    # 输出
    print(sess.run([v1, ema.average(v1)]))      
    # decay = 0.99,shadow_variable = 0.99 * 4.5 + .01*10 ⇒ 4.555

    # 再次使用滑动平均模型
    sess.run(ema_apply)

    # 输出
    print(sess.run([v1, ema.average(v1)]))      
    # decay = 0.99,shadow_variable = .99*4.555 + .01*10 = 4.609
    for i in range(1000):
        sess.run(ema_apply)
        print(sess.run([v1,ema.average(v1)]))


  • 11
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值