tf SavedModel 保存模型的新方式

原文链接: tf SavedModel 保存模型的新方式

上一篇: python pycharm 远程调试

下一篇: tf SavedModel 转换为 可使用 tfjs 加载 的形式

参考

https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel

比pb和ckpt复杂而且麻烦...

只是为了使模型在tfjs上跑所以必须使用这种方式, 及其麻烦....

简单保存和加载

保存

import tensorflow as tf

a = tf.Variable(2., name='a')
b = tf.Variable(3., name='b')
in_x = tf.placeholder(tf.float32)
export_path = './pb/v10'
out = in_x * a + b
out = tf.identity(out, 'output')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    tf.saved_model.simple_save(
        sess,
        export_path,
        inputs={"in_x": in_x},
        outputs={"output": out}
    )

加载

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

export_dir = "./pb/v10"
image_content = []

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_dir)
    print(meta_graph_def)
    signature = meta_graph_def.signature_def

    print(signature)
    x_tensor_name = signature['serving_default'].inputs["in_x"].name
    y_tensor_name = signature["serving_default"].outputs["output"].name

    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    print(x, y)
    y_out = sess.run(y, feed_dict={x: 1})
    print(y_out)

复杂自定义

保存

import tensorflow as tf

a = tf.Variable(2., name='a')
b = tf.Variable(3., name='b')
in_x = tf.placeholder(tf.float32)
export_path = './pb/v12'
out = in_x * a + b
out = tf.identity(out, 'output')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Build the signature_def_map.
    classification_inputs = tf.saved_model.utils.build_tensor_info(
        in_x)
    classification_outputs_classes = tf.saved_model.utils.build_tensor_info(
        out)

    classification_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs={
                "in_x": classification_inputs
            },
            outputs={
                "output": classification_outputs_classes,
            },
            # method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME
        )
    )

    # tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    # tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

    # prediction_signature = (
    #     tf.saved_model.signature_def_utils.build_signature_def(
    #         inputs={'images': tensor_info_x},
    #         outputs={'scores': tensor_info_y},
    #         method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    # )

    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            # 'predict_images':                prediction_signature,
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                classification_signature,
        },
        main_op=tf.tables_initializer(),
        strip_default_attrs=True)

    builder.save()

加载

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

export_dir = "./pb/v8"
image_content = []

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_dir)
    print(meta_graph_def)
    signature = meta_graph_def.signature_def

    print(signature)
    x_tensor_name = signature['serving_default'].inputs["in_x"].name
    y_tensor_name = signature["serving_default"].outputs["output"].name

    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    print(x, y)
    y_out = sess.run(y, feed_dict={x: 1})
    print(y_out)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值