作用:使用滑动平均模型可以使模型在测试数据上更健壮。即如果在测试过程中,出现了一些噪声数据,滑动平均模型可以很好地应对这些数据,使这些噪声数据不会对模型的变量造成太大的影响。
1.滑动平均模型原理:
在创建滑动平均模型后,滑动平均模型会对每一个变量维护一个影子变量(shadow variable),影子变量的初始值为相应变量的初始值,每当变量更新时,影子变量的值会更新为:
shadow_variable = decay * shadow_variable + (1-decay)*variable
shadow_variable为 影子变量,variable为待更新的变量,decay为衰减率,衰减率越大模型越稳定,因为从上实在可以看出,衰减率越大,影子变量受变量更新的影响越小。在实际应用中,decay一般会设置成非常接近1的数(如0.999或0.9999)。
[注意!变量的影子变量和变量的滑动平均值是一样的!]
滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某次的异常取值而使得滑动平均值波动很大。
而滑动平均为什么会在测试中被使用呢?
滑动平均可以使模型在测试数据上更健壮(robust)。“采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。”
对神经网络边的权重 weights 使用滑动平均,得到对应的影子变量 shadow_weights。在训练过程仍然使用原来不带滑动平均的权重 weights,不然无法得到 weights 下一步更新的值,又怎么求下一步 weights 的影子变量 shadow_weights。之后在测试过程中使用 shadow_weights 来代替 weights 作为神经网络边的权重,这样在测试数据上效果更好。因为 shadow_weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远;
设decay=0.999,一个更直观的理解,在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加robust。在整个训练过程中影子变量并不会对实际需要训练的变量产生影响啊,后面持久化的变量也不是影子变量。 在训练过程中,为参数维护更新一个影子变量,这样影子变量会停留在最终参数的周围保持稳定。 在测试阶段,使用影子变量代替参数,进行测试。
2.定义滑动平均模型:
tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型。
形式:tf.train.ExponentialMovingAverage(decay, num_updates=None, name="ExponentialMovingAverage")这个
其中有两个较为重要的参数:
decay :必填,为衰减率,用于控制模型更新的速度。
num_updates:选填,默认为none。用于控制衰减率decay的变化,若num_updates为none,则衰减率不变。当使用num_updates后,衰减率就为:
可以看出,num_updates越大,衰减率就越大。num_updates一般会为迭代轮数,所以当迭代轮数越大,模型参数就越稳定。
3.代码
import tensorflow as tf
if __name__ == "__main__":
#定义一个变量用于计算滑动平均,变量的初始值为0
v1 = tf.Variable(5,dtype=tf.float32)
#定义一个迭代轮数的变量,动态控制衰减率,并设置为不可训练
step = tf.Variable(10,trainable=False)
#定义一个滑动平均类,初始化衰减率为0.99和衰减率的变量step
ema = tf.train.ExponentialMovingAverage(0.99,step)
#定义每次滑动平均所更新的列表
maintain_average_op = ema.apply([v1])
#初始化上下文会话
with tf.Session() as sess:
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#更新v1的滑动平均值
'''
衰减率为min(0.99,(1+step)/(10+step)=0.1}=0.1
'''
sess.run(maintain_average_op)
#[5.0, 5.0]
print(sess.run([v1,ema.average(v1)]))
sess.run(tf.assign(v1,4))
sess.run(maintain_average_op)
#[4.0, 4.5500002],5*(11/20) + 4*(9/20)
print(sess.run([v1, ema.average(v1)]))
'''在实际中,v1变量很经常是网络中的权重值weights'''