tf.train.ExponentialMovingAverage用法

tf.train.ExponentialMovingAverage:通过采用指数衰减保持变量的移动平均值

tf.train.ExponentialMovingAverage(
    decay, num_updates=None, zero_debias=False, name='ExponentialMovingAverage'
)

训练模型时,保持训练参数的移动平均值通常是有益的。 使用平均参数的评估有时会产生比最终训练值明显更好的结果。

apply()方法添加训练变量的影子副本,并添加操作在其影子副本中保持训练变量的移动平均值。 在构建训练模型时使用它。维持移动平均值的操作通常在每个训练步骤之后执行 average()average_name()方法可访问影子变量及其名称。在构建评估模型从检查点文件还原模型时,它们很有用。 他们有助于使用移动平均值代替上次训练的值进行评估。

移动平均值是使用指数衰减来计算的。 在创建ExponentialMovingAverage对象时,可以指定衰减值。

影子变量使用与训练变量相同的初始值进行初始化。 当运行ops来维持移动平均值时,每个影子变量都会使用以下公式进行更新:

shadow_variable -= (1 - decay) * (shadow_variable - variable)

从数学上讲,这等效于下面的经典公式,但是使用assign_sub 操作(公式中的“-=”)允许并发无锁更新变量:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

合理的衰减值接近1.0,通常在多个九度范围内:0.999、0.9999等。

使用方法:ExponentialMovingAverage()创建一个新的ExponentialMovingAverage对象。必须调用`apply()`方法来创建影子变量并添加操作以维持移动平均值。可选的num_updates参数允许动态调整衰减率。 通常要传递训练步骤的数量,通常保持在每个步骤中递增的变量中,在这种情况下,衰减速率在训练开始时会较低。 这使移动平均值移动得更快。 如果传递,则使用的实际衰减率是:

min(decay, (1 + num_updates) / (10 + num_updates))

示例程序:

import tensorflow as tf

v1 = tf.Variable(0, dtype=tf.float32)
step = tf.Variable(tf.constant(0))

ema = tf.train.ExponentialMovingAverage(0.99, step)  
# 创建一个新的ExponentialMovingAverage对象ema
maintain_average = ema.apply([v1])  
# 调用apply()方法来创建变量v1的影子变量,并添加操作以维持移动平均值

with tf.Session() as sess:
    init = tf.initialize_all_variables()  # 定义初始化变量操作
    sess.run(init)  # 执行初始化变量操作

    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))  
    # 初始的值都为0,average()和average_name()方法可访问变量v1的影子变量及其名称

    sess.run(tf.assign(v1, 5))  # 把v1变为5
    sess.run(maintain_average)  # 执行maintain_average
    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))  
    # decay=min(0.99, 1/10)=0.1, v1_shadow=0.1*0+0.9*5=4.5

    sess.run(tf.assign(step, 10000))  # steps=10000
    sess.run(tf.assign(v1, 10))  # v1=10
    sess.run(maintain_average)
    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))
    # decay=min(0.99,(1+10000)/(10+10000))=0.99,v1_shadow=0.99*4.5+0.01*10=4.555

    sess.run(maintain_average)
    print(sess.run([v1, ema.average(v1), ema.average_name(v1)]))
    # decay=min(0.99,(1+10000)/(10+10000))=0.99,
    # v1_shadow=0.99*4.555+0.01*10=4.609449999999999

# 输出结果:
# [0.0, 0.0, None]
# [5.0, 4.5, None]
# [10.0, 4.555, None]
# [10.0, 4.60945, None]

 

import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
05-26
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值