TensorFlow手把手入门之 — TensorFlow保存还原模型的正确方式,Saver的save和restore方法,亲测可用

TensorFlow保存还原模型的正确方式,Saver的save和restore方法,亲测可用

许多TensorFlow初学者想把自己训练的模型保存,并且还原继续训练或者用作测试。但是TensorFlow官网的介绍太不实用,网上的资料又不确定哪个是正确可行的。

今天David 9 就来带大家手把手入门亲测可用的TensorFlow保存还原模型的正确方式,使用的是网上最多的Saver的save和restore方法, 并且把关键点为大家指出。

今天介绍最为可行直接的方式来自这篇Stackoverflow:https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model 亲测可用:

保存模型:

    import tensorflow as tf
    #Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1= tf.Variable(2.0,name="bias")
    feed_dict ={w1:4,w2:8}
    #Define a test operation that we will restore
    w3 = tf.add(w1,w2)
    w4 = tf.multiply(w3,b1,name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    #Create a saver object which will save all the variables
    saver = tf.train.Saver()
    #Run the operation by feeding input
    print sess.run(w4,feed_dict)
    #Prints 24 which is sum of (w1+w2)*b1 
    #Now, save the graph
    saver.save(sess, 'my_test_model',global_step=1000)

必须强调的是:这里4,5,6,11行中的name=’w1′, name=’w2′,  name=’bias’, name=’op_to_restore’ 千万不能省略,这是恢复还原模型的关键。

还原模型:

    import tensorflow as tf
    sess=tf.Session()    
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2, which is the value of bias that we saved
    # Now, let's access and create placeholders variables and
    # create feed-dict to feed new data
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict ={w1:13.0,w2:17.0}
    #Now, access the op that you want to run. 
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    print sess.run(op_to_restore,feed_dict)
    #This will print 60 which is calculated

还原当然是用restore方法,这里的18,19,23行就是刚才的name关键字指定的Tensor变量,必须找对才能进行还原恢复。

其他的关键在代码和注释中可以一眼看出, 这里不加赘述了。

参考文献:

  1. https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model
  2. https://nathanbrixius.wordpress.com/2016/05/24/checkpointing-and-reusing-tensorflow-models/
  3. https://stackoverflow.com/questions/42685994/how-to-get-a-tensorflow-op-by-name
  4. http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
  5. https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops/saving_and_restoring_variables


  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值