折腾了我几天,一直搞不定。最后用以下代码成功保存。
方法一:
tensorflow2.0以上版本可以使用
tf.saved_model.save(model, "save_test")
model = tf.saved_model.load("save_test")
来保存成pb文件以及读取,但是保存的是将模型和权重独立。
2020.3.1更新:
下面方法为新的保存方法,可以直接将模型和权重保存为pb文件。
2020.5.6更新:
保存成pb模型必须在程序最开始处调用:
tf.enable_eager_execution()
使其进入eager模式。
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir="",
name="frozen_graph.pb",
as_text=False)
with tf.Graph().as_default():
output_graph_def = tf.compat.v1.GraphDef()
# 打开.pb模型
with open("frozen_graph.pb", "rb") as f:
output_graph_def.ParseFromString(f.read())
tensors = tf.import_graph_def(output_graph_def, name='')
# print("tensors:", tensors)
with tf.compat.v1.Session() as sess:
op = sess.graph.get_operations()
for i, m in enumerate(op):
print('op{}:'.format(i), m.values())
input_x = sess.graph.get_tensor_by_name("x:0") #可以看op的首末名input.name