#测试模型持久化
v1 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v1')
v2 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v2')
result = v2 + v1
init_op = tf.global_variables_initializer()
#声明tf.train.Saver类用于保存模型
saver=tf.train.Saver()
with tf.compat.v1.Session() as sess:
sess.run(init_op)
saver.save(sess,"papay\\1.ckpt")
生成文件列表如下
文件说明:checkpoint-保存了目录下所有模型文件列表 meta-保存了计算图的结果,index和data文件保存了变量的取值
#测试模型加载-重复运算
v1 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v1')
v2 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v2')
result = v2 + v1
saver=tf.train.Saver()
with tf.compat.v1.Session() as sess:
saver.restore(sess,"papay\\1.ckpt")
sess.run(result)
#测试模型加载-直接加载持久化的计算图,无需重复运算
saver = tf.train.import_meta_graph("papay\\1.ckpt.meta")
with tf.compat.v1.Session() as sess:
saver.restore(sess,"papay\\1.ckpt")
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
#测试变量重命名
v1 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v100')#将变量名从v1重命名为v100
v2 = tf.Variable(tf.constant(1.,shape=[2,2]),name='v200')
result = v2 + v1
saver=tf.train.Saver({"v1":v1,"v2":v2})#将原来名称为v1的变量加载到v1中
with tf.compat.v1.Session() as sess:
saver.restore(sess,"papay\\1.ckpt")
sess.run(result)
#测试变量重命名在滑动平均中的应用
v = tf.Variable(1.,dtype=tf.float32,name='v')
#没有加入滑动平均,输出为"v:0",0代表计算节点的第一次输出,详见张量
for variables in tf.global_variables():
print(variables.name)
#使用滑动平均后,会自动生成一个影子变量"v/ExponentialMovingAverage:0"
ema = tf.train.ExponentialMovingAverage(0.9)
avg_op = ema.apply(tf.global_variables())
#输出"v:0"和"v/ExponentialMovingAverage:0"
for variables in tf.global_variables():
print(variables.name)
saver = tf.train.Saver()
with tf.compat.v1.Session() as sess:
init_op = sess.run(tf.global_variables_initializer())
sess.run(tf.assign(v,10))
sess.run(avg_op)
saver.save(sess,"avg\\1.ckpt")
print(sess.run([v,ema.average(v)]))
#加载持久化文件测试
v = tf.Variable(1.,dtype=tf.float32,name='v')
#调用ema.variables_to_restore()等价于saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
saver = tf.train.Saver(ema.variables_to_restore())
with tf.compat.v1.Session() as sess:
saver.restore(sess,"avg\\1.ckpt")
#输出1.9000002
print(sess.run(v))
#测试通过graph_util.convert_variables_to_constants将变量保存为常量,保存为pb文件
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.,shape=[2,2],dtype=tf.float32,name='v1'))
v2 = tf.Variable(tf.constant(1.,shape=[2,2],dtype=tf.float32,name='v2'))
result = v1 + v2
with tf.compat.v1.Session() as sess:
sess.run(tf.global_variables_initializer())
#仅导出计算图的graphdef部分,就可以完成输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
#将图中变量转化为常量,add为需要保存的节点
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,["add"])
#导出模型到文件,注意,不会自动生成目录
with tf.gfile.GFile("1.pb","wb") as f:
f.write(output_graph_def.SerializeToString())
#模型加载
from tensorflow.python.platform import gfile
with tf.compat.v1.Session() as sess:
model_filename = "1.pb"
#读取模型文件,将文件解析为对应的GraphDef Protocal Buffer
with gfile.FastGFile(model_filename,"rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
#加载时使用张量名称
result = tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(result))