一、保存模型
例:
import tensorflow as tf
var1 = tf.Variable(tf.constant(1.0,shape=[1]),name='num1')
var2 = tf.Variable(tf.constant(2.0,shape=[1]),name='num2')
var = var1 +var2
#定义存储器
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(var))
saver.save(sess,'/xx/saver/model.ckpt') //将模型保存至XX路径下;
存储后,该路径下将会有以下几个文件:
-rw-rw-r-- 1 xxxx xxxx 143 1月 8 23:45 checkpoint
-rw-rw-r-- 1 xxxx xxxx 8 1月 8 23:27 .data-00000-of-00001
-rw-rw-r-- 1 xxxx xxxx 145 1月 8 23:27 .index
-rw-rw-r-- 1 xxxx xxxx 3578 1月 8 23:27 .meta
-rw-rw-r-- 1 xxxx xxxx 8 1月 8 23:45 model.ckpt.data-00000-of-00001
-rw-rw-r-- 1 xxxx xxxx 145 1月 8 23:45 model.ckpt.index
-rw-rw-r-- 1 xxxx xxxx 3578 1月 8 23:45 model.ckpt.meta
二、加载模型
例:
import tensorflow as tf
#加载图的结构
saver = tf.train.import_meta_graph('/home/abig/vscode/project1/saver/model.ckpt.meta')
with tf.Session() as sess:
#加载变量值,并将变量值存入sess中
saver.restore(sess,'/home/abig/vscode/project1/saver/model.ckpt')
#打印出变量值
print(sess.run(tf.get_default_graph().get_tensor_by_name("num1:0")))
print(sess.run(tf.get_default_graph().get_tensor_by_name("num2:0")))
输出结果为
[ 1.]
[ 2.]