作者:chen_h
微信号 & QQ:862251340
微信公众号:coderpai
今天学习了一下Tensorflow模型的保存和加载,查看了API文档,但是没有很理解,所以从网上找了一个比较简单的实现。
比如,我们需要保存的模型是参数v1
和v2
,那么只需要使用下列的保存代码save_model.py
。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
v1 = tf.Variable(1.1, name="v1")
v2 = tf.Variable(1.2, name="v2")
init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
print v2.eval(sess)
save_path="model.ckpt"
saver.save(sess,save_path)
print "Model stored...."
如果,我们要恢复模型,并且把他们导入到变量中,那么首先定义两个参数v3
和v4
,给他们取名叫v1
和v2
。注意,这里必须要给v3
和v4
取名为v1
和v2
,因为我们保存的模型中给变量取的名字就是v1
和v2
。那么,模型恢复的代码为restore_model.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
v3 = tf.Variable(0.0, name="v1")
v4 = tf.Variable(0.0, name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
save_path="model.ckpt"
saver.restore(sess, save_path)
print("Model restored.")
print sess.run(v3)