1 将模型导出为 SavedModel（saved_model.pb + variables/）
# coding=utf-8
from tensorflow.contrib.saved_model.python.saved_model.utils import simple_save
from options import Options
from data_provider import *
from collections import Counter, defaultdict
import operator
import os
import numpy as np
import tensorflow as tf
from model import Model
from options import Options
# Restore the latest training checkpoint and export the model as a TensorFlow
# SavedModel (saved_model.pb + variables/) under EXPORT_DIR.
EXPORT_DIR = "./model"

opt = Options()
sess_config = tf.ConfigProto(allow_soft_placement=True)
sess_config.gpu_options.allow_growth = True

# Fail fast, before building the graph, if there is nothing to export ...
ckpt_path = tf.train.latest_checkpoint(opt.save_path)
if ckpt_path is None:
    raise RuntimeError("No checkpoint found in %r" % (opt.save_path,))
# ... or if the SavedModel builder would refuse the target directory
# (it will not write into an existing, non-empty export directory).
if os.path.isdir(EXPORT_DIR) and os.listdir(EXPORT_DIR):
    raise RuntimeError(
        "Export directory %r already exists and is not empty; "
        "remove it or choose another path" % (EXPORT_DIR,))

print("Building model...")
model = Model(opt)
# NOTE(review): is_train is assigned after Model(opt) has built the graph;
# whether Model reads it afterwards cannot be confirmed from this file.
model.is_train = False

with tf.Session(config=sess_config) as sess:
    # Initialise variables and lookup tables; restore() then overwrites every
    # variable tracked by the Saver with its checkpointed value.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_path)
    print('Model successfully loaded')

    # Table initialisers are not restored from variables, so they are bundled
    # as the legacy init op, which the SavedModel loader runs after restoring.
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

    # The dict keys become the input/output names of the default serving
    # signature (inspect them with `saved_model_cli show`).
    simple_save(
        sess,
        EXPORT_DIR,
        inputs={
            "x_": model.x_,
            "x_length": model.x_length,
            "dropout": model.dropout,
        },
        outputs={
            "Softmax_1": model.scores[0],
            "Softmax_3": model.scores[1],
            "Softmax_5": model.scores[2],
            "Softmax_7": model.scores[3],
            "Softmax_9": model.y_score,
        },
        legacy_init_op=legacy_init_op,
    )
2 在 Python 中加载并调用 SavedModel
2.1 查询模型的输入/输出 tensor 名称
saved_model_cli show --dir model/ --tag_set serve --signature_def serving_default
2.2 调用
import tensorflow as tf
import numpy as np
saved_model_dir = "model"
with tf.Session() as sess1:
# load pb
meta_graph_def = tf.saved_model.loader.load(sess1, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
# get 输入输出
input = sess1.graph.get_tensor_by_name('Test/Model/inputs:0')
target = sess1.graph.get_tensor_by_name('Test/Model/targets:0')
cost = sess1.graph.get_tensor_by_name('Test/Model/truediv:0')
inputsINT = sess1.graph.get_tensor_by_name('Test/Model/hash_table_Lookup:0')
targetsINT = sess1.graph.get_tensor_by_name('Test/Model/hash_table_Lookup_1:0')
#嗯,社保缴费
feed_dict = {input: [["嗯", ",", "社", "保", "缴", "费"]],
target: [[",", "社", "保", "缴", "费", "<eos>"]]}
y=sess1.run([cost,inputsINT,targetsINT], feed_dict=feed_dict)
print(y)
print(y[0])
print(np.exp(y[0]/6))