保存
用tf.train.Saver
类保存模型
import tensorflow as tf
a = tf.Variable(2, dtype=tf.float32)
print(a) # <tf.Variable 'Variable:0' shape=() dtype=float32_ref>
b = tf.multiply(a, 3)
print(b) # Tensor("Mul:0", shape=(), dtype=float32)
c = tf.add(a, b)
print(c) # Tensor("Add:0", shape=(), dtype=float32)
# 建立saver
saver = tf.train.Saver()
with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 循环更新变量数值
for i in range(50):
print(sess.run(c))
sess.run(tf.assign_add(a, 1.5))
# 保存模型
# 下方代码中loop.ckpt为保存模型文件的名称
# 模型名称中的.ckpt无特别意义,常用于表明保存checkpoint
# global_step参数用于在保存文件名中加入迭代次数,默认保存最近4次迭代的模型
saver.save(sess, "d:/temp/untitled/loop.ckpt", global_step=i)
"""
在d:/temp/untitled/文件夹下生成13个文件
checkpoint
loop.ckpt-46.meta
loop.ckpt-46.index
loop.ckpt-46.data-00000-of-00001
loop.ckpt-47.meta
loop.ckpt-47.index
loop.ckpt-47.data-00000-of-00001
loop.ckpt-48.meta
loop.ckpt-48.index
loop.ckpt-48.data-00000-of-00001
loop.ckpt-49.meta
loop.ckpt-49.index
loop.ckpt-49.data-00000-of-00001
"""
# 保存模型
# 下方代码中model为保存模型文件的名称
saver.save(sess, "d:/temp/model")
"""
在d:/temp/文件夹下生成4个文件
checkpoint
model.meta
model.index
model.data-00000-of-00001
"""
载入
仅载入模型数据
import tensorflow as tf
# 建立与保存时一致的模型
# 变量初始值随便,因为在下方会从模型保存文件中导入
a = tf.Variable(0.237, dtype=tf.float32)
b = tf.multiply(a, 3)
c = tf.add(a, b)
# 建立saver
saver = tf.train.Saver()
with tf.Session() as sess:
# 载入模型数据
# 注意路径名称与save函数一致,不是填写文件名
saver.restore(sess, "d:/temp/model")
print(sess.run(a))
print(sess.run(b))
print(sess.run(c))
仅载入计算图
import tensorflow as tf
# 载入模型图,不包含数据
# 注意路径指向为save时保存生成的.meta文件
saver = tf.train.import_meta_graph("d:/temp/model.meta")
with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 取得计算图
graph = tf.get_default_graph()
# 取得tensor
a_tensor = graph.get_tensor_by_name("Variable:0")
b_tensor = graph.get_tensor_by_name("Mul:0")
c_tensor = graph.get_tensor_by_name("Add:0")
print(sess.run(a_tensor)) # 变量的初始值2
print(sess.run(b_tensor)) # 首轮计算的数值6
print(sess.run(c_tensor)) # 首轮计算的数值8
载入计算图和数据
import tensorflow as tf
# 载入模型图,不包含数据
# 注意路径指向为save时保存生成的.meta文件
saver = tf.train.import_meta_graph("d:/temp/model.meta")
with tf.Session() as sess:
# 载入模型数据
# 注意下方数据来源于同一模型,但是并非在一个save函数中保存,结果正确
# 最好用同一个save生成的模型数据"d:/temp/model"
saver.restore(sess, "d:/temp/untitled/loop.ckpt-46")
# 取得计算图
graph = tf.get_default_graph()
# 取得tensor
a_tensor = graph.get_tensor_by_name("Variable:0")
b_tensor = graph.get_tensor_by_name("Mul:0")
c_tensor = graph.get_tensor_by_name("Add:0")
print(sess.run(a_tensor))
print(sess.run(b_tensor))
print(sess.run(c_tensor))
固化
在模型完成训练后,用于实际计算应用场景下时,可以将变量转化为常量,成为计算图的一部分。一方面方便保存与载入,另一方面提高运算效率。
固化保存代码如下:
import tensorflow as tf
x = tf.placeholder(dtype=tf.float32, shape=[]) # 定义输入
print(x) # Tensor("Placeholder:0", shape=(), dtype=float32)
a = tf.Variable(2, dtype=tf.float32)
print(a) # <tf.Variable 'Variable:0' shape=() dtype=float32_ref>
b = tf.multiply(a, 3)
print(b) # Tensor("Mul:0", shape=(), dtype=float32)
c = tf.add(a, b)
print(c) # Tensor("Add:0", shape=(), dtype=float32)
y = tf.subtract(x, c)
print(y) # Tensor("Sub:0", shape=(), dtype=float32)
with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 循环更新变量数值,模拟训练的过程
for i in range(50):
print(sess.run(c))
sess.run(tf.assign_add(a, 1.5))
# 生成当前计算图的GraphDef
graph_def = tf.get_default_graph().as_graph_def()
# 将当前图中变量全部转为常量
# 注意,output_node_names传入输出节点名称列表,注意节点名称与tensor名称的区别
graph_def_output = tf.graph_util.convert_variables_to_constants(sess=sess, input_graph_def=graph_def, output_node_names=["Placeholder", "Sub", "Variable"])
# 下方代码生成保存文件constant_model,大多数情况下会加上.pb扩展名表明类型
with tf.gfile.FastGFile("d:/temp/constant_model", "wb") as f:
f.write(graph_def_output.SerializeToString())
固化载入代码如下:
import tensorflow as tf
with tf.Session() as sess:
with tf.gfile.FastGFile("d:/temp/constant_model", "rb") as f:
# 建立GraphDef,用于导入计算图
graph_def = tf.GraphDef()
# 从文件内容解析GraphDef内容
graph_def.ParseFromString(f.read())
# 从GraphDef加载计算图
[x_tensor, y_tensor, a_tensor] = tf.import_graph_def(graph_def, return_elements=["Placeholder:0", "Sub:0", "Variable:0"])
print(sess.run(a_tensor))
print(sess.run(y_tensor, feed_dict={x_tensor: 1000}))
注意
关于CNN实际计算时去除训练中的batch维度:
在载入CNN网络模型时,如果用于实际计算,而不是用于训练,则取出计算输出的tensor,在run()
函数中feed_dict
参数输入的placeholder,batch维度的数值设置为1,即输入tensor的shape为(1, 行数目, 列数目, 通道数目)
因为计算输出的tensor的计算中不包含需要从batch维度取值的操作。