TensorFlow 中 sparse_placeholder 在 SavedModel 中保存 pb 模型的使用方法

# coding:utf-8

# @author: liu
# @file: sparse_tensor_pb.py
# @time: 2020/4/3 11:00
# @desc: export a graph fed by tf.sparse_placeholder as a SavedModel (pb) and reload it for prediction

# coding:utf-8


import tensorflow as tf
import random
import numpy as np
import os
import shutil


# Report the installed TensorFlow version. This script relies on TF 1.x APIs
# (tf.Session, tf.sparse_placeholder, tf.saved_model.builder), which are not in
# the top-level namespace of TensorFlow 2.x (they live under tf.compat.v1 there).
print(tf.__version__)

def create_sparse(batch_size, dtype=np.int32, *, min_len=150, max_len=180, max_value=779):
    """Create a random sparse label batch as ``[indices, values, dense_shape]``.

    ctc_loss expects its ``labels`` as a sparse tensor; this helper generates
    such labels for experiments. Row ``i`` gets a random length drawn from
    ``[min_len, max_len]`` (inclusive) and every entry a random integer in
    ``[0, max_value]`` (inclusive). The defaults reproduce the original
    150~180 lengths and 0~779 values, and consume ``random`` in the same
    order, so seeded runs give identical results.

    Args:
        batch_size: number of rows (sequences) to generate; 0 is allowed.
        dtype: numpy dtype of the returned ``values`` array.
        min_len: inclusive lower bound on each row's length.
        max_len: inclusive upper bound on each row's length.
        max_value: inclusive upper bound of each label value.

    Returns:
        ``[indices, values, shape]``: ``indices`` is an int64 array of shape
        ``(nnz, 2)`` holding ``(row, col)`` pairs, ``values`` has shape
        ``(nnz,)`` and dtype ``dtype``, and ``shape`` is the int64 dense
        shape ``[batch_size, longest_row_length]``.
    """
    indices = []
    values = []
    for i in range(batch_size):
        length = random.randint(min_len, max_len)
        for j in range(length):
            indices.append((i, j))
            values.append(random.randint(0, max_value))

    # reshape keeps the (nnz, 2) layout even when no entries were produced.
    indices = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    values = np.asarray(values, dtype=dtype)
    # Dense width = largest column index + 1 (the longest row). Guard the
    # empty case: max() of an empty array would raise ValueError.
    width = int(indices[:, 1].max()) + 1 if len(indices) else 0
    shape = np.asarray([batch_size, width], dtype=np.int64)

    return [indices, values, shape]

# 保存成pb模型

def saved_model(sess: tf.Session, input: tf.SparseTensor, ss, cc, model_path):
    """Export the graph in ``sess`` as a SavedModel with a "predict" signature.

    The sparse input is registered as three separate signature inputs, one
    per component tensor: "indices", "values" and "dense_shapes". A loader
    must therefore feed all three components by tensor name.

    Args:
        sess: session whose graph contains ``input``, ``ss`` and ``cc``.
        input: the sparse placeholder. ``tf.sparse_placeholder`` returns a
            ``tf.SparseTensor``. The parameter name shadows the builtin but
            is kept for compatibility with keyword callers.
        ss: tensor exported under output key "ss".
        cc: tensor exported under output key "cc".
        model_path: export directory. An existing path is deleted first,
            because SavedModelBuilder will not export into an existing
            (non-empty) directory.
    """
    if os.path.exists(model_path):
        shutil.rmtree(model_path)
    builder = tf.saved_model.builder.SavedModelBuilder(model_path)

    inputs = {
        "indices": tf.saved_model.build_tensor_info(input.indices),
        "values": tf.saved_model.build_tensor_info(input.values),
        # Note the plural key: load_model looks up "dense_shapes".
        "dense_shapes": tf.saved_model.build_tensor_info(input.dense_shape),
    }
    outputs = {
        "ss": tf.saved_model.build_tensor_info(ss),
        "cc": tf.saved_model.build_tensor_info(cc),
    }

    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )

    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={"predict": prediction_signature},
    )
    builder.save()

def load_model(model_path=None):
    """Load a SavedModel's "predict" signature and resolve its tensor names.

    The model is loaded into a fresh ``tf.Graph`` rather than the
    process-wide default graph. This way it cannot collide with ops that
    already exist there, for example a graph built earlier in the same
    process with the same op names.

    Args:
        model_path: SavedModel export directory. It is required; the ``None``
            default is kept only for signature compatibility.

    Returns:
        ``(sess, indices, values, dense_shape, output_a, output_b)``:
        - ``sess`` is an open session that owns the loaded graph. The caller
          must close it.
        - The other five are tensor *names* (strings), usable as
          ``feed_dict`` keys and ``fetches``.
        - The names cover the "indices", "values" and "dense_shapes" inputs
          and the "ss" / "cc" outputs.

    Raises:
        ValueError: if ``model_path`` is ``None``.
    """
    if model_path is None:
        raise ValueError("model_path must point to a SavedModel directory")

    sess = tf.Session(graph=tf.Graph())
    try:
        meta_graph = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], model_path)
        signature = meta_graph.signature_def["predict"]
        indices = signature.inputs["indices"].name
        values = signature.inputs["values"].name
        dense_shape = signature.inputs["dense_shapes"].name
        output_a = signature.outputs["ss"].name
        output_b = signature.outputs["cc"].name
    except Exception:
        # Don't leak the session when the export is missing or lacks the
        # expected signature keys.
        sess.close()
        raise

    return sess, indices, values, dense_shape, output_a, output_b


def train(model_path="model/1"):
    """Build a toy graph on a sparse placeholder, run it once and export it.

    The graph computes two outputs from the float32 sparse input ``a``:
    ``c``, its dense form, and ``s``, the sum of its values. Both are
    evaluated on a fixed 4x5 example and printed. They are then exported,
    together with the placeholder, via ``saved_model``.

    Args:
        model_path: export directory (defaults to "model/1", as before).
    """
    # Build in a private graph so repeated calls don't keep adding ops (with
    # uniquified names such as "in_1") to the process-wide default graph.
    graph = tf.Graph()
    with graph.as_default():
        a = tf.sparse_placeholder(tf.float32, name="in")
        c = tf.sparse_to_dense(a.indices, a.dense_shape, a.values)
        s = tf.sparse.reduce_sum(a)

        # Example: nonzeros 1, 4, 5 at (0,1), (0,4), (2,3) of a 4x5 matrix.
        # Indices are kept in sorted, unique order because tf.sparse_to_dense
        # validates that by default.
        indices_list = [[0, 1], [0, 4], [2, 3]]
        values_list = [1, 4, 5]
        dense_shape_list = [4, 5]
        example = tf.SparseTensorValue(
            indices=indices_list, values=values_list, dense_shape=dense_shape_list)

        # The graph has no variables, so no initializer run is needed.
        # saved_model is called while ``graph`` is still the default graph,
        # because the builder's Saver adds its ops to the default graph.
        with tf.Session(graph=graph) as sess:
            ss, cc = sess.run([s, c], feed_dict={a: example})
            print(ss)
            print("-----------")
            print(cc)
            saved_model(sess, a, s, c, model_path)


def predict(model_path="model/1"):
    """Reload the exported model and run it on a fixed sparse example.

    The exported "predict" signature exposes the sparse input as three
    separate tensors, so the example is fed component by component (indices,
    values, dense_shape). The tensor names come from ``load_model``.

    Args:
        model_path: SavedModel directory to load (defaults to "model/1",
            as before).
    """
    # 4x5 example with nonzeros 1, 4, 5 at (0,1), (0,4), (2,3).
    indices_list = [[0, 1], [0, 4], [2, 3]]
    values_list = [1, 4, 5]
    dense_shape_list = [4, 5]

    sess, indices, values, dense_shape, output_a, output_b = load_model(model_path)

    # Using the session as a context manager closes it when the block exits.
    with sess:
        a, b = sess.run([output_a, output_b], feed_dict={
            indices: indices_list,
            values: values_list,
            dense_shape: dense_shape_list,
        })
        print("预测结果")
        print(a)
        print(b)



if __name__ == "__main__":
    # Only run the demo when executed as a script: importing this module
    # (e.g. to reuse create_sparse) should not train and write model files.
    # Export the model first, then reload it from disk and run it.
    train()
    predict()

















  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值