Service 脚本
"""
NLG ckpt/pb model server: Tornado HTTP endpoint that generates titles from a frozen TensorFlow (.pb) graph.
"""
from tornado.options import define, options
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
import json
import tensorflow as tf
from tensorflow.compat.v1.train import NewCheckpointReader
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#from create_tf_record import *
from tensorflow.python.framework import graph_util
from hparams_dh import Hparams
from data_utils import get_hypotheses
import os
from data_preprocess import Dataset
import logging
logging.basicConfig(level=logging.INFO)
# Default to GPU 2, but respect a device selection the operator already made
# in the environment instead of silently overriding it.
# NOTE(review): this presumably only takes effect if set before TensorFlow
# initialises CUDA (no session has been created at this point).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
# Model/data hyper-parameters come from the project's argparse parser
# (Hparams.parser); this parses sys.argv at import time.
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
# Shared batcher/vocabulary used by ckpt_predict.predict (get_batch, idx2token).
dataset = Dataset(hp.maxlen1, hp.maxlen2, hp.vocab, hp.batch_size)
class ckpt_predict:
    """Generate titles with a frozen TensorFlow graph (``.pb`` file).

    The graph is parsed and a session is opened lazily on the first call to
    :meth:`predict` and then reused for every later request. The original
    code re-read the ``.pb`` file and built a new graph/session per request.
    Changes to the file on disk are picked up only after :meth:`close`
    (the next ``predict`` reloads) or a process restart.
    """

    def __init__(self, pb_path):
        """
        :param pb_path: path to the frozen ``.pb`` graph file.
        """
        self.pb_path = pb_path
        # Filled in by _ensure_loaded(); None until the first predict().
        self._sess = None
        self._inputx = None
        self._output = None

    def _ensure_loaded(self):
        """Load the frozen graph and open a session once; no-op afterwards."""
        if self._sess is not None:
            return
        graph = tf.Graph()
        with graph.as_default():
            graph_def = tf.compat.v1.GraphDef()
            with open(self.pb_path, "rb") as f:
                graph_def.ParseFromString(f.read())
            # name="" keeps the original op names so "inputx:0"/"output:0" resolve.
            tf.import_graph_def(graph_def, name="")
        self._inputx = graph.get_tensor_by_name("inputx:0")
        self._output = graph.get_tensor_by_name("output:0")
        # Assign the session last: if anything above fails the object stays
        # unloaded and the next call retries.
        self._sess = tf.compat.v1.Session(graph=graph)

    def predict(self, keys):
        """Generate titles for a batch of keys.

        :param keys: input keys, passed unchanged to ``dataset.get_batch``.
        :return: ``{"titles": ...}`` holding the post-processed hypotheses
            decoded by ``get_hypotheses``.
        """
        datas, _ = dataset.get_batch(keys)
        logging.info("input batch shape: %s", datas.shape)
        self._ensure_loaded()
        predict = self._sess.run(self._output, feed_dict={self._inputx: datas})
        hypotheses = get_hypotheses(predict, dataset.idx2token)
        return {"titles": self.postprocessor(hypotheses)}

    def postprocessor(self, data):
        """Hook for post-processing decoded hypotheses; currently returns them unchanged."""
        logging.debug("hypotheses (%s): %s", type(data), data)
        return data

    def close(self):
        """Release the cached TensorFlow session, if one was opened."""
        if self._sess is not None:
            self._sess.close()
            self._sess = None
class IndexHandler(tornado.web.RequestHandler):
    """HTTP endpoint that returns generated titles for a list of keys.

    GET and POST behave identically: both expect a JSON body of the form
    ``{"keys": [...]}`` and respond with the dict returned by ``cp.predict``
    (serialised to JSON by ``RequestHandler.write``). Malformed requests get
    HTTP 400 instead of an unhandled exception (HTTP 500).

    NOTE(review): relies on a module-level ``cp`` predictor, which is only
    created in the ``__main__`` block of this file.
    """

    def _read_keys(self):
        """Parse the JSON request body and return its non-empty ``keys`` value."""
        try:
            payload = json.loads(self.request.body)
        except ValueError as err:  # covers JSONDecodeError and UnicodeDecodeError
            raise tornado.web.HTTPError(
                400, "request body is not valid JSON: %s", err) from err
        if not isinstance(payload, dict):
            raise tornado.web.HTTPError(400, "request body must be a JSON object")
        keys = payload.get("keys")
        if not keys:
            raise tornado.web.HTTPError(400, "'keys' is missing or empty")
        return keys

    def _respond(self):
        """Run the model on the request's keys and write the result."""
        self.write(cp.predict(self._read_keys()))

    def get(self):
        self._respond()

    def post(self):
        self._respond()
# Tornado command-line option: HTTP port to listen on.
define("port", default=8080, help="run on the given port", type=int)

if __name__ == '__main__':
    # Frozen graph is expected at ./pb_model/tansformer.pb (relative to the cwd).
    pb_path = os.path.join(os.getcwd(), "pb_model", "tansformer.pb")
    # Module-level predictor that IndexHandler uses for every request.
    cp = ckpt_predict(pb_path)

    # NOTE(review): `hp = parser.parse_args()` at import time also parses
    # sys.argv, so a --port flag may be rejected by argparse (and hparams
    # flags by tornado). Confirm how this server is launched.
    tornado.options.parse_command_line()
    app = tornado.web.Application(handlers=[(r"/nlg/titles", IndexHandler)])
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    # Log only once the socket is actually bound.
    logging.info('server is started ……')
    # IOLoop.instance() is a deprecated alias (Tornado 5+); current() is the
    # supported way to get this thread's loop.
    tornado.ioloop.IOLoop.current().start()
Client 脚本(客户端)
import json
import requests
def getResult(keys, url="http://172.19.80.37:8080/nlg/titles", timeout=60):
    """POST ``keys`` to the NLG title service and return the raw response text.

    :param keys: list of keys to generate titles for.
    :param url: endpoint of the title service (defaults to the previously
        hard-coded server).
    :param timeout: seconds ``requests`` waits for the server before raising;
        pass ``None`` to wait indefinitely (the old behaviour).
    :return: the response body as text (JSON on success).
    """
    # json= serialises the payload and sets Content-Type: application/json.
    res = requests.post(url, json={"keys": keys}, timeout=timeout)
    return res.text
if __name__ == '__main__':
    # Smoke test: request titles for a few sample keys and print the raw reply.
    keys = [
        'tower-t13001-coffee-anti-drip-feature',
        'rainier-gear-digital-vision-binocular',
        'zeiss-compact-pocket-grey-black-binocular',
        'barska-blackhawk-18-36x50-waterproof-spotting',
        'celestron-52331-trailseeker-65-straight',
        'baselay-night-vision-glasses-driving',
    ]
    res = getResult(keys)
    print(res)