一、参考资料
TensorFlow模型的保存与加载(二)——pb模式【源码】
深度学习 TensorFlow中模型的freeze_graph
TensorFlow 保存模型为 PB 文件
二、相关介绍
MetaGraph
MetaGraphDef
是MetaGraph
的 Protocol Buffer
表示。
MetaGraph 包括计算图(网络结构),数据流图,以及相关的变量和输入输出的signature。
protocol buffer
(pb)
pb 是 MetaGraph 的 protocol buffer(pb )文件格式。
Frozen GraphDef
格式
Frozen GraphDef
将冻结(Frozen)TensorFlow导出的模型权重参数,使得其都变为常量。并且模型权重参数和网络结构保存在同一个文件中,可以在python以及java中自由调用。
Frozen GraphDef
格式,属于冻结(Frozen)的GraphDef文件,这种文件格式不包含Variables节点,而是将GraphDef中所有Variables固定为常量(其值从checkpoint获取)。Frozen GraphDef
格式将计算图和权重以常量的形式保存在一张静态图(pb)中。
When you are saving your graph, a MetaGraph is created. This is the graph itself, and all the other metadata necessary for computations in this graph, as well as some user info that can be saved and version specification.
三、关键步骤
1. 定义计算图
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
with tf.Session() as sess:
# 定义计算图结构 x*y+b
x = tf.placeholder(tf.int32, name='x_input')
y = tf.placeholder(tf.int32, name='y_input')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output') # 指定输出节点的名称name
# 输出计算图结果
sess.run(tf.global_variables_initializer())
y_pred = sess.run(output, feed_dict={x: 10, y: 3})
print(y_pred) # 输出31
2. 保存pb模型
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework.graph_util import convert_variables_to_constants
import os
def save_pb_model(sess, save_path, output_nodes):
"""
convert_variables_to_constants 需要指定输出节点的名称
参数1:session会话
参数2:计算图的graph_def对象
参数3:输出节点的名称,list类型,可以多个
"""
# output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_nodes)
output_graph_def = convert_variables_to_constants(sess,
input_graph_def=tf.get_default_graph().as_graph_def(),
output_node_names=output_nodes)
# 检查是否存在路径
path = os.path.abspath(save_path) # 获取绝对路径
if os.path.exists(path) is False:
os.makedirs(path)
print("成功创建模型保存新路径:{}".format(path))
# 将计算图写入序列化的pb文件
with tf.gfile.GFile(save_path + "frozen_graph.pb", mode="wb") as f:
f.write(output_graph_def.SerializeToString())
print("成功使用PB模式保存模型到路径:{}".format(path))
with tf.Session() as sess:
# 定义计算图结构 x*y+b
x = tf.placeholder(tf.int32, name='x_input')
y = tf.placeholder(tf.int32, name='y_input')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output') # 指定输出节点的名称name
sess.run(tf.global_variables_initializer()) # 初始化过程
# 保存pb模型
save_pb_model(sess, './models/', ['output'])
3. 加载pb模型
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def load_pb_model(sess, save_path):
"""read pb into graph_def"""
with tf.gfile.FastGFile(save_path + 'frozen_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
"""import graph_def"""
sess.graph.as_default()
tf.import_graph_def(graph_def, name='') # 导入计算图
with tf.Session() as sess:
# 加载pb模型
load_pb_model(sess, './models/')
# 获取计算图节点
graph = tf.get_default_graph() # 获取计算图
x = graph.get_tensor_by_name("x_input:0")
y = graph.get_tensor_by_name("y_input:0")
output = graph.get_tensor_by_name("output:0")
sess.run(tf.global_variables_initializer()) # 初始化过程
# 运行计算图计算结果
y_pred = sess.run(output, feed_dict={x: 10, y: 3})
print(y_pred) # 输出 31
4. 完整代码
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
def save_pb_model(sess, save_path, output_nodes):
"""
convert_variables_to_constants 需要指定输出节点的名称
参数1:session会话
参数2:计算图的graph_def对象
参数3:输出节点的名称,list类型,可以多个
"""
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(),
output_nodes)
# 检查是否存在路径
path = os.path.abspath(save_path) # 获取绝对路径
if os.path.exists(path) is False:
os.makedirs(path)
print("成功创建模型保存新路径:{}".format(path))
# 将计算图写入序列化的pb文件
with tf.gfile.GFile(save_path + "frozen_graph.pb", mode="wb") as f:
f.write(output_graph_def.SerializeToString())
print("成功使用PB模式保存模型到路径:{}".format(path))
def load_pb_model(sess, save_path):
with tf.gfile.FastGFile(save_path + 'frozen_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='') # 导入计算图
with tf.Session() as sess:
# 定义计算图结构 x*y+b
x = tf.placeholder(tf.int32, name='x_input')
y = tf.placeholder(tf.int32, name='y_input')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output') # 指定输出节点的名称name
sess.run(tf.global_variables_initializer()) # 初始化过程
# 输出计算图结果
y_pred = sess.run(output, feed_dict={x: 10, y: 3})
print(y_pred) # 输出31
# 保存pb模型
save_pb_model(sess, './models/', ['output'])
# 加载pb模型
load_pb_model(sess, './models/')
# 直接获取保存的变量
print(sess.run('b:0'))
# 获取计算图节点
graph = tf.get_default_graph() # 获取计算图
x = graph.get_tensor_by_name("x_input:0")
y = graph.get_tensor_by_name("y_input:0")
output = graph.get_tensor_by_name("output:0")
# 运行计算图计算结果
y_pred = sess.run(output, feed_dict={x: 10, y: 3})
print(y_pred) # 输出 31