tensorflow保存模型和加载模型的方法
参考:
1、首先计算图定义代码如下:
import tensorflow as tf
a = tf.placeholder(dtype=tf.float64,shape=[1,2],name="input_data")
b = tf.Variable(initial_value=1.0,dtype=tf.float64,name="b")
b = tf.assign(b,2)
out = tf.add(a,b,name="output_data")
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess,save_path="./model/demo")
# write_meta_graph和write_state置为False,则不保存checkpoint和.meta
# saver.save(sess,save_path="./model/demo",
# write_meta_graph=False,
# write_state=False)
执行之后,在当前目录生成了一个model目录,包含4个文件,如下所示。
- 后缀为meta的文件保存了图结构,可以用tf.train.import_meta_graph加载保存的图结构。
- 后缀为data和index的文件保存了网络图中变量的值,通过saver = tf.train.Saver加载参数的时候提供这两个文件即可,因此可以把write_meta_graph和write_state置为False(请看上面的代码),则不保存checkpoint和.meta。
- placeholder中的值不会保存。
2、加载保存的计算图
tf.train.Saver提供了一个restore()方法,恢复之前保存的变量,但是使用restore()之前必须将计算图定义一遍,也可以通过tf.train.import_meta_graph方法从.meta文件中读取图结构。
下面的例子在加载变量之前重新实现了一遍图结构。
import tensorflow as tf
import numpy as np
# 模型定义
a = tf.placeholder(dtype=tf.float64,shape=[1,2],name="input_data")
b = tf.Variable(initial_value=1.0,dtype=tf.float64,name="b")
b = tf.assign(b,2)
out = tf.add(a,b,name="output_data")
# 模型定义结束
saver = tf.train.Saver()
with tf.Session() as sess:
model_file=tf.train.latest_checkpoint("./model/")
saver.restore(sess,model_file)
input_data = np.array([[10, 10]])
print(sess.run("output_data:0", feed_dict={"input_data:0": input_data}))
下面的例子是从.meta文件中加载图结构。
import tensorflow as tf
import numpy as np
with tf.Session() as sess:
structure = tf.train.import_meta_graph("./model/demo.meta")
structure.restore(sess,tf.train.latest_checkpoint("./model/"))
input_data = np.array([[10,10]])
out = sess.run("output_data:0",feed_dict={"input_data:0":input_data})
print(sess.run("b:0"))
print(out)