variables_to_restore函数的用法

variables_to_restore是为了在保持模型的时候方便使用滑动平均的参数,如果不使用这个保存,那模型就会保存所以参数,除非你提前设定,就是在保存的时候指定保存变量也是可以的,比如saver = tf.train.Saver([v])这样就可以指定保存变量v,在模型导入的时候只有这个变量会被导入。

比如:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver()
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())

	sess.run(tf.assign(v, 10.0))
	sess.run(maintain_average_op)
	saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
模型导入:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver()
with tf.Session() as sess:
	# sess.run(tf.initialize_all_variables())

	# sess.run(tf.assign(v, 10.0))
	# sess.run(maintain_average_op)
	# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')

	saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
	print sess.run(ema.average(v))
	print sess.run(v)
输出:

0.0999999
10.0

这样不是很方便,因为我再次导入模型,变量v的值我不用,并且想要用计算后的值替代v,这样在模型被导入就方便就算

下面代码显示如何使用:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver()
with tf.Session() as sess:
	sess.run(tf.initialize_all_variables())

	sess.run(tf.assign(v, 10.0))
	sess.run(maintain_average_op)
	saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
	print sess.run(v)
	print sess.run(ema.average(v))

	# saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
	# print sess.run(v)
输出:

10.0
0.0999999


导入模型的时候tf.train.Saver函数要变化一下,变为tf.train.Saver(ema.variables_to_restore()),代码如下:

import tensorflow as tf;  
import numpy as np;  
import matplotlib.pyplot as plt;  

v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.all_variables())

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
	# sess.run(tf.initialize_all_variables())

	# sess.run(tf.assign(v, 10.0))
	# sess.run(maintain_average_op)
	# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
	# print sess.run(v)
	# print sess.run(ema.average(v))

	saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')
	print sess.run(v)
输出:

0.0999999


注意:如果不变的话,那么输出就会是10!




  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
最新发布
05-26

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值