Tensorflow保存模型,恢复模型,使用训练好的模型进行预测和提取中间输出(特征)【转】

来自:http://blog.csdn.net/ying86615791/article/details/72731372

前言:
tensorflow中有operation和tensor,前者表示 操作 ,后者表示 容器 ,每个operation都是有一个tensor来存放值的,比如y=f(x), operation是f(x), tensor存放的就是y,如果要获取y,就必须输入x
tensor的名字一般是 <operation>:<num>

可以通过 print(out.name) 来看看


假如之前的训练定义了如下图(模型),并保存:

[python]  view plain  copy
  1. ....  
  2. bottom = layers.fully_connected(inputs=bottom, num_outputs=7, activation_fn=None, scope='logits_classifier')  
  3. ......  
  4. prediction = tf.nn.softmax(logits, name='prob')  
  5. ......  
  6. saver_path = './model/checkpoint/model.ckpt'  
  7. saver = tf.train.Saver()  
  8. config = tf.ConfigProto()  
  9. config.gpu_options.allow_growth=True  
  10. with tf.Session(config=config) as sess:  
  11.     sess.run(init)  
  12. ...  
  13.     saved_path = saver.save(sess,saver_path) # 这个保存了三个东西, .meta是图的结构, 还有两个是模型中变量的值  
  14. ...  

要想图结构和模型(恢复图结构,没错,从空白的代码段中恢复一个graph,就不需要重新定义图了)

[python]  view plain  copy
  1.    meta_path = './model/checkpoint/model.ckpt.meta'  
  2.    model_path = './model/checkpoint/model.ckpt'  
  3.    saver = tf.train.import_meta_graph(meta_path) # 导入图  
  4.      
  5.    config = tf.ConfigProto()  
  6.    config.gpu_options.allow_growth=True  
  7.    with tf.Session(config=config) as sess:  
  8.        saver.restore(sess, model_path) # 导入变量值  
  9.        graph = tf.get_default_graph()  
  10.        prob_op = graph.get_operation_by_name('prob'# 这个只是获取了operation, 至于有什么用还不知道  
  11. prediction = graph.get_tensor_by_name('prob:0'# 获取之前prob那个操作的输出,即prediction  
  12. print( ress.run(prediciton, feed_dict={...})) # 要想获取这个值,需要输入之前的placeholder  
  13.        print(sess.run(graph.get_tensor_by_name('logits_classifier/weights:0'))) # 这个就不需要feed了,因为这是之前train operation优化的变量,即模型的权重  


关于获取保存的模型中的tensor或者输出,还有一种办法就是用tf.add_to_collection(),
假如上面每次定义一次运算后,可以在后面添加tf.add_to_collection():

[python]  view plain  copy
  1. ......  
  2. bottom = layers.fully_connected(inputs=bottom, num_outputs=7, activation_fn=None, scope='logits_classifier')  
  3. ### add collection  
  4. tf.add_to_collection('logits',bottom)  
  5. ......  
  6. prediction = tf.nn.softmax(logits, name='prob')  
  7. ### add collection  
  8. tf.add_to_collection('prob',prediction)  
  9. ......  

恢复模型后,通过tf.get_collection()来获取tensor:

[python]  view plain  copy
  1. ......  
  2. x = tf.get_collection('inputs')[0]  
  3. prob = tf.get_collection('prob')[0]  
  4. print(x)  
  5. print(prob)  
  6. .....  
可以查看输出,效果是和上面get_tensor_by_name()一样的,注意get_collection(name)的name只是collection的name,tensor的名字还是原来的名字


得到了模型各个地方的tensor之后,要想获取该地方的参数或者输出的值,只需要通过sess.run()就可以了,参数可以直接run,中间的特征或者预测值需要通过feed_dict={}传递输入的值就行啦



  • 0
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值