TensorFlow的文件保存与读取——variables_to_restore函数

转,原创详见: http://blog.csdn.net/sinat_29957455/article/details/78508793

variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

1、滑动平均值模型文件的保存

[python]  view plain  copy
  1. import tensorflow as tf  
  2.   
  3. if __name__ == "__main__":  
  4.     v = tf.Variable(0.,name="v")  
  5.     #设置滑动平均模型的系数  
  6.     ema = tf.train.ExponentialMovingAverage(0.99)  
  7.     #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量  
  8.     op = ema.apply([v])  
  9.     #获取变量v的名字  
  10.     print(v.name)  
  11.     #v:0  
  12.     #创建一个保存模型的对象  
  13.     save = tf.train.Saver()  
  14.     sess = tf.Session()  
  15.     #初始化所有变量  
  16.     init = tf.initialize_all_variables()  
  17.     sess.run(init)  
  18.     #给变量v重新赋值  
  19.     sess.run(tf.assign(v,10))  
  20.     #应用平均滑动设置  
  21.     sess.run(op)  
  22.     #保存模型文件  
  23.     save.save(sess,"./model.ckpt")  
  24.     #输出变量v之前的值和使用滑动平均模型之后的值  
  25.     print(sess.run([v,ema.average(v)]))  
  26.     #[10.0, 0.099999905]  
上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。

2、滑动平均值模型文件的读取

[python]  view plain  copy
  1. v = tf.Variable(1.,name="v")  
  2. #定义模型对象  
  3. saver = tf.train.Saver({"v/ExponentialMovingAverage":v})  
  4. sess = tf.Session()  
  5. saver.restore(sess,"./model.ckpt")  
  6. print(sess.run(v))  
  7. #0.0999999  
对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是 {"v/ExponentialMovingAverage":v}而不是 {"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是 变量本身而不是 影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

3、variables_to_restore函数的使用

[python]  view plain  copy
  1. v = tf.Variable(1.,name="v")  
  2. #滑动模型的参数的大小并不会影响v的值  
  3. ema = tf.train.ExponentialMovingAverage(0.99)  
  4. print(ema.variables_to_restore())  
  5. #{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}  
  6. sess = tf.Session()  
  7. saver = tf.train.Saver(ema.variables_to_restore())  
  8. saver.restore(sess,"./model.ckpt")  
  9. print(sess.run(v))  
  10. #0.0999999  
通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。



  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
最新发布
05-26
以下是对代码的改进建议: 1. 在代码开头添加注释,简要说明代码功能和使用方法。 2. 将导入模块的语句放在代码开头。 3. 将模型保存路径和评估时间间隔定义为常量,并使用有意义的变量名。 4. 将计算正确率和加载模型的过程封装为函数。 5. 在主函数中调用评估函数。 改进后的代码如下: ``` # 该代码实现了使用已训练好的模型对 MNIST 数据集进行评估 import time import tensorflow.compat.v1 as tf from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train # 定义常量 MODEL_SAVE_PATH = 'model/' EVAL_INTERVAL_SECS = 10 def evaluate(mnist): """ 计算模型在验证集上的正确率 """ with tf.Graph().as_default() as g: # 定义输入和输出格式 x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') # 直接调用封装好的函数计算前向传播结果 y = mnist_inference.inference(x, None) # 计算正确率 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 加载模型 variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY) variables_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variables_to_restore) # 在验证集上计算正确率 with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict={x: mnist.validation.images, y_: mnist.validation.labels}) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') def main(argv=None): # 读取数据集 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 每隔一定时间评估模型在验证集上的正确率 while True: evaluate(mnist) time.sleep(EVAL_INTERVAL_SECS) if __name__ == '__main__': tf.app.run() ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值