TensorFlow pb 模型的 Tornado 服务部署

Service 脚本(服务端)

"""
NLG ckpt/pb model server
"""
from tornado.options import define, options
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
import json
import tensorflow as tf
from tensorflow.compat.v1.train import NewCheckpointReader
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#from create_tf_record import *
from tensorflow.python.framework import graph_util
from hparams_dh import Hparams
from data_utils import get_hypotheses
import os
from data_preprocess import Dataset
import logging

logging.basicConfig(level=logging.INFO)
# Pin the process to GPU 2.
# NOTE(review): this only works if it runs before TensorFlow initialises its
# GPU devices (no session has been created at this point) — keep it up here.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


# Model/data hyper-parameters are read from sys.argv at import time.
# parse_known_args() (instead of parse_args()) leaves flags that belong to
# other parsers — e.g. tornado's --port=9000 — alone, rather than aborting
# the process with an "unrecognized arguments" error.
hparams = Hparams()
parser = hparams.parser
hp, _ = parser.parse_known_args()
# Shared encoder/decoder for requests; its idx2token table is used to turn
# model outputs back into text.
dataset = Dataset(hp.maxlen1, hp.maxlen2, hp.vocab, hp.batch_size)

class ckpt_predict:
    """Run inference with a frozen TensorFlow graph (.pb file).

    The GraphDef is read, imported and wrapped in a session on the first call
    to :meth:`predict`. Later calls reuse that session instead of re-reading
    and re-importing the whole model for every request.
    """

    def __init__(self, pb_path):
        """
        :param pb_path: path to the frozen GraphDef (.pb) file.
        """
        self.pb_path = pb_path
        # Populated by _ensure_loaded() on first use, then reused.
        self._sess = None
        self._inputx = None
        self._output = None

    def _ensure_loaded(self):
        """Import the frozen graph and open a session for it, once."""
        if self._sess is not None:
            return
        graph = tf.Graph()
        with graph.as_default():
            graph_def = tf.compat.v1.GraphDef()
            with open(self.pb_path, "rb") as f:
                graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name="")
        inputx = graph.get_tensor_by_name("inputx:0")
        output = graph.get_tensor_by_name("output:0")
        self._inputx, self._output = inputx, output
        # Assigned last, so a failure above is retried on the next call
        # instead of leaving a half-initialised predictor behind.
        self._sess = tf.compat.v1.Session(graph=graph)

    def predict(self, keys):
        """Generate titles for ``keys``.

        :param keys: raw keys, encoded into model input by ``dataset.get_batch``.
        :return: ``{"titles": <post-processed hypotheses>}``.
        """
        datas, _ = dataset.get_batch(keys)
        logging.info(datas.shape)
        self._ensure_loaded()
        outputs = self._sess.run(self._output, feed_dict={self._inputx: datas})
        hypotheses = get_hypotheses(outputs, dataset.idx2token)
        return {"titles": self.postprocessor(hypotheses)}

    def close(self):
        """Release the TensorFlow session, if one was opened; safe to repeat."""
        if self._sess is not None:
            self._sess.close()
        self._sess = None
        self._inputx = None
        self._output = None

    def postprocessor(self, data):
        """Hook for post-processing decoded hypotheses.

        Currently an identity pass-through that only logs its input.
        """
        logging.info("hypotheses (%s): %s", type(data).__name__, data)
        return data


class IndexHandler(tornado.web.RequestHandler):
    """Tornado handler for the titles endpoint.

    GET and POST both expect a JSON object body such as
    ``{"keys": ["key-1", "key-2"]}``. They respond with the dict returned by
    the module-level predictor ``cp``; tornado serialises dicts to JSON.
    Malformed requests get HTTP 400 with a JSON ``{"error": ...}`` body
    instead of an unhandled exception.
    """

    def _bad_request(self, message):
        """Answer with HTTP 400 and a JSON error payload."""
        self.set_status(400)
        self.write({"error": message})

    def _handle(self):
        """Shared GET/POST logic: parse the body, validate keys, predict."""
        try:
            payload = json.loads(self.request.body)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (both are
            # ValueError subclasses), including an empty body.
            self._bad_request("request body must be a JSON object")
            return
        if not isinstance(payload, dict):
            self._bad_request("request body must be a JSON object")
            return
        keys = payload.get("keys")
        if not isinstance(keys, list) or not keys:
            self._bad_request('"keys" must be a non-empty JSON array')
            return
        # `cp` is the predictor created in this module's __main__ section.
        self.write(cp.predict(keys))

    def get(self):
        self._handle()

    def post(self):
        self._handle()


define("port", default=8080, help="run on the given port", type=int)  # HTTP listen port


if __name__ == '__main__':
    import sys

    # Frozen graph location, relative to the current working directory.
    # NOTE: the file name "tansformer.pb" is kept verbatim; it must match
    # the file on disk.
    pb_path = os.path.join(os.getcwd(), "pb_model", "tansformer.pb")

    # IndexHandler looks up `cp` as a module-level global.
    cp = ckpt_predict(pb_path)

    # The hyper-parameter parser reads sys.argv separately at import time.
    # tornado.options raises on any option it does not define, so forward
    # only the "--name=value" arguments that name tornado options
    # (e.g. --port=9000, --logging=debug).
    tornado_argv = [sys.argv[0]] + [
        arg for arg in sys.argv[1:]
        if arg.startswith("--") and arg[2:].partition("=")[0] in options
    ]
    tornado.options.parse_command_line(tornado_argv)

    app = tornado.web.Application(handlers=[(r"/nlg/titles", IndexHandler)])
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    # Logged only once the socket is actually bound.
    logging.info('server is started ……')
    # IOLoop.instance() is a deprecated alias of IOLoop.current().
    tornado.ioloop.IOLoop.current().start()




Client 脚本(客户端)

import json
import requests

def getResult(keys, url="http://172.19.80.37:8080/nlg/titles", timeout=60):
    """POST ``keys`` to the NLG title service and return the raw response text.

    :param keys: list of product keys to generate titles for.
    :param url: service endpoint; defaults to the deployment's host/port.
    :param timeout: seconds to wait for the server. Without a timeout,
        requests can block indefinitely on an unresponsive server.
    :return: the response body as text. This is returned as-is even for
        HTTP error statuses.
    """
    # ``json=`` serialises the payload and sets Content-Type: application/json.
    response = requests.post(url, json={"keys": keys}, timeout=timeout)
    return response.text


if __name__ == '__main__':
    # Demo: request titles for a handful of sample product keys and print
    # whatever the service returns.
    sample_keys = [
        'tower-t13001-coffee-anti-drip-feature',
        'rainier-gear-digital-vision-binocular',
        'zeiss-compact-pocket-grey-black-binocular',
        'barska-blackhawk-18-36x50-waterproof-spotting',
        'celestron-52331-trailseeker-65-straight',
        'baselay-night-vision-glasses-driving',
    ]
    print(getResult(sample_keys))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值