Reference GitHub repo: https://github.com/aaxwaz/Serving-TensorFlow-Model
Start the TensorFlow Serving model server:
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
nohup bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=example_model --model_base_path=/Users/shuubiasahi/Desktop/modelserving > log.log &
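The server loads a SavedModel from --model_base_path. For context, here is a minimal export sketch (my assumption, not taken from the linked repo) that would produce a model matching the signature name 'prediction' and the tensor keys 'input'/'output' used by the Flask code below:

# Hypothetical export script; the real one lives in the linked repo.
import tensorflow as tf

export_path = '/Users/shuubiasahi/Desktop/modelserving/1'  # '1' is the model version dir

x = tf.placeholder(tf.float32, shape=[None, 3], name='input')
w = tf.constant([[1.0], [2.0], [3.0]])
y = tf.matmul(x, w, name='output')  # toy model: y = x . [1, 2, 3]

with tf.Session() as sess:
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tf.saved_model.utils.build_tensor_info(x)},
        outputs={'output': tf.saved_model.utils.build_tensor_info(y)},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={'prediction': signature})
    builder.save()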
Flask server code:
"""This script wraps the client into a Flask server. It receives POST request with
prediction data, and forward the data to tensorflow server for inference.
"""
from flask import Flask, render_template, request, url_for, jsonify,Response
import json
import tensorflow as tf
import numpy as np
import os
import argparse
import sys
from datetime import datetime
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
tf.app.flags.DEFINE_string('server', 'localhost:9000', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS
app = Flask(__name__)
class mainSessRunning():
def __init__(self):
host, port = FLAGS.server.split(':')
channel = implementations.insecure_channel(host, int(port))
self.stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
self.request = predict_pb2.PredictRequest()
self.request.model_spec.name = 'example_model'
self.request.model_spec.signature_name = 'prediction'
def inference(self, val_x):
# temp_data = numpy.random.randn(100, 3).astype(numpy.float32)
temp_data = val_x.astype(np.float32).reshape(-1, 3)
print("temp_data is:", temp_data)
data, label = temp_data, np.sum(temp_data * np.array([1, 2, 3]).astype(np.float32), 1)
self.request.inputs['input'].CopyFrom(
tf.contrib.util.make_tensor_proto(data, shape=data.shape))
result = self.stub.Predict(self.request, 5.0)
return result, label
run = mainSessRunning()
print("Initialization done. ")
# Define a route for the default URL, which loads the form
@app.route('/inference', methods=['POST'])
def inference():
request_data = request.json
input_data = np.expand_dims(np.array(request_data), 0)
result, label = run.inference(input_data)
di={"result":str(result),'label': label[0].tolist()}
return Response(json.dumps(di), mimetype='application/json')
#return jsonify(di)
@app.route('/test', methods=['GET'])
def test_serv():
return ("Hello")
if __name__ == "__main__":
app.run()
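Note that str(result) in the /inference handler serializes the whole PredictResponse proto as text (visible in the result at the end). If the endpoint should return plain numbers instead, the tensor values can be pulled out of the response; a minimal sketch, assuming the output key 'output' used above:

# Sketch: convert the response's TensorProto into a numpy array and
# return its values instead of the raw proto text.
output_array = tf.contrib.util.make_ndarray(result.outputs['output'])
di = {"result": output_array.tolist(), 'label': label[0].tolist()}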
Calling the Flask server from a client:
import requests, json
data =[1,0.5,2.0]
headers = {'content-type': 'application/json'}
r = requests.post("http://127.0.0.1:5000/inference", data=json.dumps(data),
headers=headers)
print(r.json())
Result (label is the ground truth 1*1 + 0.5*2 + 2.0*3 = 8.0; float_val is the served model's prediction):
{'result': 'outputs {\n key: "output"\n value {\n dtype: DT_FLOAT\n tensor_shape {\n dim {\n size: 1\n }\n dim {\n size: 1\n }\n }\n float_val: 7.69464254379\n }\n}\n', 'label': 8.0}