使用tensorflow保存、加载和使用模型

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

[python]  view plain  copy
  1. #!/usr/bin/env python  
  2. #-*- coding:utf-8 -*-  
  3. ############################  
  4. #File Name: tut1_save.py  
  5. #Author: Wang   
  6. #Mail: wang19920419@hotmail.com  
  7. #Created Time:2017-08-30 11:04:25  
  8. ############################  
  9.   
  10. import tensorflow as tf  
  11.   
  12. # prepare to feed input, i.e. feed_dict and placeholders  
  13. w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1')  # name is very important in restoration  
  14. w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')  
  15. b1 = tf.Variable(2.0, name = 'bias1')  
  16. feed_dict = {w1:[10,3], w2:[5,5]}  
  17.   
  18. # define a test operation that will be restored  
  19. w3 = tf.add(w1, w2)  # without name, w3 will not be stored  
  20. w4 = tf.multiply(w3, b1, name = "op_to_restore")  
  21.   
  22. #saver = tf.train.Saver()  
  23. saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)  
  24. sess = tf.Session()  
  25. sess.run(tf.global_variables_initializer())  
  26. print sess.run(w4, feed_dict)  
  27. #saver.save(sess, 'my_test_model', global_step = 100)  
  28. saver.save(sess, 'my_test_model')  
  29. #saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)  
需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。


下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

[python]  view plain  copy
  1. #!/usr/bin/env python  
  2. #-*- coding:utf-8 -*-  
  3. ############################  
  4. #File Name: tut2_import.py  
  5. #Author: Wang   
  6. #Mail: wang19920419@hotmail.com  
  7. #Created Time:2017-08-30 14:16:38  
  8. ############################  
  9.   
  10. import tensorflow as tf  
  11.   
  12. sess = tf.Session()  
  13. new_saver = tf.train.import_meta_graph('my_test_model.meta')  
  14. new_saver.restore(sess, tf.train.latest_checkpoint('./'))  
  15. print sess.run('w1:0')  

使用加载的模型,输入新数据,计算输出,还是直接上代码:

[python]  view plain  copy
  1. #!/usr/bin/env python  
  2. #-*- coding:utf-8 -*-  
  3. ############################  
  4. #File Name: tut3_reuse.py  
  5. #Author: Wang  
  6. #Mail: wang19920419@hotmail.com  
  7. #Created Time:2017-08-30 14:33:35  
  8. ############################  
  9.   
  10. import tensorflow as tf  
  11.   
  12. sess = tf.Session()  
  13.   
  14. # First, load meta graph and restore weights  
  15. saver = tf.train.import_meta_graph('my_test_model.meta')  
  16. saver.restore(sess, tf.train.latest_checkpoint('./'))  
  17.   
  18. # Second, access and create placeholders variables and create feed_dict to feed new data  
  19. graph = tf.get_default_graph()  
  20. w1 = graph.get_tensor_by_name('w1:0')  
  21. w2 = graph.get_tensor_by_name('w2:0')  
  22. feed_dict = {w1:[-1,1], w2:[4,6]}  
  23.   
  24. # Access the op that want to run  
  25. op_to_restore = graph.get_tensor_by_name('op_to_restore:0')  
  26.   
  27. print sess.run(op_to_restore, feed_dict)     # ouotput: [6. 14.]  

在已经加载的网络后继续加入新的网络层:

[python]  view plain  copy
  1. import tensorflow as tf  
  2.   
  3. sess=tf.Session()      
  4. #First let's load meta graph and restore weights  
  5. saver = tf.train.import_meta_graph('my_test_model-1000.meta')  
  6. saver.restore(sess,tf.train.latest_checkpoint('./'))  
  7.   
  8.   
  9. # Now, let's access and create placeholders variables and  
  10. # create feed-dict to feed new data  
  11.   
  12. graph = tf.get_default_graph()  
  13. w1 = graph.get_tensor_by_name("w1:0")  
  14. w2 = graph.get_tensor_by_name("w2:0")  
  15. feed_dict ={w1:13.0,w2:17.0}  
  16.   
  17. #Now, access the op that you want to run.   
  18. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")  
  19.   
  20. #Add more to the current graph  
  21. add_on_op = tf.multiply(op_to_restore,2)  
  22.   
  23. print sess.run(add_on_op,feed_dict)  
  24. #This will print 120.  

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

[python]  view plain  copy
  1. ......  
  2. ......  
  3. saver = tf.train.import_meta_graph('vgg.meta')  
  4. # Access the graph  
  5. graph = tf.get_default_graph()  
  6. ## Prepare the feed_dict for feeding data for fine-tuning   
  7.   
  8. #Access the appropriate output for fine-tuning  
  9. fc7= graph.get_tensor_by_name('fc7:0')  
  10.   
  11. #use this if you only want to change gradients of the last layer  
  12. fc7 = tf.stop_gradient(fc7) # It's an identity function  
  13. fc7_shape= fc7.get_shape().as_list()  
  14.   
  15. new_outputs=2  
  16. weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))  
  17. biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))  
  18. output = tf.matmul(fc7, weights) + biases  
  19. pred = tf.nn.softmax(output)  
  20.   
  21. # Now, you run this with fine-tuning data in sess.run()  

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值