原文链接: tf SavedModel 保存模型的新方式
上一篇: python pycharm 远程调试
下一篇: tf SavedModel 转换为 可使用 tfjs 加载 的形式
参考
https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel
SavedModel 比 pb 和 ckpt 格式更复杂，也更麻烦……
之所以使用这种方式，只是因为要让模型在 tfjs 上运行就必须先保存为 SavedModel，过程极其麻烦……
简单保存和加载（使用 simple_save）
保存
# Simple export: build a tiny graph computing output = in_x * a + b and save
# it as a SavedModel with tf.saved_model.simple_save (TensorFlow 1.x API).
import tensorflow as tf

# Model parameters, given their values by the global initializer below.
a = tf.Variable(2., name='a')
b = tf.Variable(3., name='b')
# Input placeholder with no fixed shape, so any float32 tensor can be fed.
in_x = tf.placeholder(tf.float32)
# NOTE: the SavedModel builder behind simple_save will not write into an
# existing export directory (it raises AssertionError), so use a fresh path,
# e.g. bump the version suffix, for every export.
export_path = './pb/v10'
out = in_x * a + b
# Give the result a stable graph name ("output") so it is easy to locate later.
out = tf.identity(out, 'output')

with tf.Session() as sess:
    # Variables must hold values before they can be saved.
    sess.run(tf.global_variables_initializer())
    # simple_save writes the graph and variables under the SERVING tag, with a
    # default serving signature keyed by the dict keys below.
    tf.saved_model.simple_save(
        sess,
        export_path,
        inputs={"in_x": in_x},
        outputs={"output": out},
    )
加载
# Load the SavedModel written by the simple export and evaluate it once.
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

# Must match the export_path used when the model was saved.
export_dir = "./pb/v10"

with tf.Session() as sess:
    # Restore graph + variables into this session; returns the MetaGraphDef.
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tag_constants.SERVING], export_dir)
    print(meta_graph_def)
    signature = meta_graph_def.signature_def
    print(signature)
    # The exported signature lives under the default serving key
    # ("serving_default"). Resolve the real tensor names through it instead of
    # hard-coding graph tensor names.
    serving_signature = signature[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    x_tensor_name = serving_signature.inputs["in_x"].name
    y_tensor_name = serving_signature.outputs["output"].name
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    print(x, y)
    # With the saved values a=2, b=3 this evaluates 1 * 2 + 3.
    y_out = sess.run(y, feed_dict={x: 1})
    print(y_out)
复杂的自定义方式（使用 SavedModelBuilder）
保存
# Custom export: same graph as the simple example, but the SignatureDef is
# built by hand and written with SavedModelBuilder (TensorFlow 1.x API).
import tensorflow as tf

a = tf.Variable(2., name='a')
b = tf.Variable(3., name='b')
in_x = tf.placeholder(tf.float32)
# Must be a fresh directory: SavedModelBuilder will not overwrite an existing
# export (it raises AssertionError).
export_path = './pb/v12'
out = in_x * a + b
out = tf.identity(out, 'output')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Describe the input/output tensors (name, dtype, shape) for the signature.
    tensor_info_in_x = tf.saved_model.utils.build_tensor_info(in_x)
    tensor_info_output = tf.saved_model.utils.build_tensor_info(out)

    # A generic "predict" signature, which is what simple_save produces too.
    # A Classify signature is not used here: it expects serialized tf.Example
    # string inputs, which this toy graph does not take.
    prediction_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={"in_x": tensor_info_in_x},
            outputs={"output": tensor_info_output},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )
    )

    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            # Registered under "serving_default" so loaders can find it
            # without knowing a custom signature key.
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature,
        },
        # Op run when the model is loaded; initializes lookup tables
        # (this graph defines none).
        main_op=tf.tables_initializer(),
        # Drop default-valued op attributes for better forward compatibility
        # with older TensorFlow consumers.
        strip_default_attrs=True,
    )
    builder.save()
加载
# Load the SavedModel written by the custom SavedModelBuilder export and
# evaluate it once.
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

# Must match the export_path used by the builder export ('./pb/v12'); the
# previous value './pb/v8' pointed at a different or nonexistent model.
export_dir = "./pb/v12"

with tf.Session() as sess:
    # Restore graph + variables into this session; returns the MetaGraphDef.
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tag_constants.SERVING], export_dir)
    print(meta_graph_def)
    signature = meta_graph_def.signature_def
    print(signature)
    # The builder registered its signature under the default serving key
    # ("serving_default"); resolve tensor names through it.
    serving_signature = signature[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    x_tensor_name = serving_signature.inputs["in_x"].name
    y_tensor_name = serving_signature.outputs["output"].name
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    print(x, y)
    # With the saved values a=2, b=3 this evaluates 1 * 2 + 3.
    y_out = sess.run(y, feed_dict={x: 1})
    print(y_out)