tensorflow提供了tf.train.saver类来保存还原一个神经网络模型。
1.保存计算图
以下为保存计算图的方法:
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
init_op = tf.global_variables_initializer()
# 声明tf.train.saver类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, "./checkpoint_dir/MyModel")
然后会生成三个文件。
第一个为model.ckpt.meta,保存了tensorflow计算图的结构,即网络结构。
第二个为model.ckpt,保存了Tensorflow中每一个变量的取值。
第三个为checkpoint文件,保存了一个目录下所有的模型文件列表。
以下为加载方法:
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "./checkpoint_dir/MyModel")
print(sess.run(result))
2.变量重命名
在保存和加载时给变量重命名,使用字典:
saver = tf.train.Saver({"v1":v1,"v2":v2})
从而方便使用滑动平均值(将影子变量映射到变量自身,从而不需要再次调用函数计算)
# 保存滑动平均模型,从而不需要再次计算
v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
print(variables.name)
#输出v:0
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#输出v/ExponentialMovingAverage:0
saver = tf.train.Saver()
with tf.Session() as sess:
ini_op = tf.global_variables_initializer()
sess.run(ini_op)
sess.run(tf.assign(v, 10))
sess.run(maintain_averages_op)
saver.save(sess, "./checkpoint_dir/MyModel01")
print(sess.run([v, ema.average(v)]))
#输出[10.0, 0.099999905]
读取变量的滑动平均值:
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
# variables_to_restore生成字典来通过变量重命名直接读取变量的滑动平均值
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
saver.restore(sess, "./checkpoint_dir/MyModel")
print(sess.run(v))
variables_to_restore生成字典={“v/ExponentialMovingAverage”: v}
3.将Tensorflow程序放在一个文件中
在测试或离线预测时,只需要知道如何从神经网络输入层计算到输入层,不需要变量初始化,模型保存等辅助信息,而convert_variables_to_constants函数将计算图中的变量及取值通过常量的方式保存到一个文件中,
实例如下:
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# 导出计算图的GraphDef部分,这一部分即可完成从输入层到输出层的计算
graph_def = tf.get_default_graph().as_graph_def()
# 将变量及取值转化为常量,并将不必要的节点去掉。'add'为需要保存的节点名称
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
with tf.gfile.GFile("./simple_model", "wb")as f:
f.write(output_graph_def.SerializeToString())
以下程序可直接得到定义的加法运算的结果,适用于只需要某个节点的值时,
在迁移学习中将得到应用。
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = "./combined model"
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef() # 读取
graph_def.ParseFromString(f.read()) # 解析成PB文件
# ["add:0"]是张量的名称
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print(sess.run(result))