保存滑动平均值的模型
example:
import tensorflow as tf
v = tf.Variable(0, dtype = tf.float32, name = "v")
#没有申明滑动平均模型时只有一个变量v
for variables in tf.all_variables():
pirnt(variable.name)
#输出v:0
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())
#申明滑动平均模型之后,Tensorflow 会自动生成一个影子变量"v/ExponentialMoving Average
for variables in tf.all_variable():
print(variables.name)
#output: v:0, v/ExponentialMoving: 0
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.initialize_all_variables()
sess.run(init_op)
sess.run(tf.assgin(v, 10))
sess.run(maintain_average_op)
saver.save(sess, "/path/to/model/model.ckpt")
#v:o, v/ExeponentialMovingAverage:0,都将保存
print(sess.run([v, ema.average(v)])
#输出:[10.0, 0.099999905]
通过变量重新命名直接读取变量的滑动平均值
v = tf.Variable(0, dtype = tf.float32, name = 'v')
#通过变量重命名 将原来的变量v的ExeponentialMovingAverage 赋值给v
saver = tf.train.Saver({"v/ExeponentialMovingAverage": v})
with tf.Session() as sess:
saver.restore(sess, "/path/to/model/model.ckpt")
print("sess.run(v)")
#输出 0.099999905, 为原来模型中变量v的滑动平均值
为了方便加载时候的重命名, ema.variables_to_restore() 生成tf.train.Saver类所需要的变量重命名字典。
即把saver = tf.train.Saver({“v/ExeponentialMovingAverage”: v})
替换成saver = tf.train.Saver(ema.variables_to_restore())