一、flask安装
flask是一个轻量级基于Python的web框架
安装执行pip install Flask
二、代码实例
1、服务端代码
以上一篇文章深度学习onnx图像分类推理为例,通过post请求输入一张base64编码的图像到服务端并且返回识别结果的类别和概率,将需要识别的部分写成方法调用,如模型加载等初始化内容放在方法外面,只需要启动时加载一次。
代码如下:
# myweb.py
import cv2
import numpy as np
import base64
import onnxruntime as ort
import argparse
import flask
from PIL import Image
app = flask.Flask(__name__)
def softmax(x):
x = x.reshape(-1)
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)
def postprocess(result):
"""
执行softmax
"""
return softmax(np.array(result)).tolist()
def recognition(img):
"""
识别模块
img:输入待识别图像
"""
img = Image.fromarray(img)
img = img.resize((128, 128), 0)
img = np.asarray(img, np.float32)/255.0
img = img[np.newaxis, np.newaxis, :, :]
input_blob = np.array(img, dtype=np.float32)
onnx_result = ort_session.run([onnx_outputs_names], input_feed={onnx_input_name: input_blob})
res = postprocess(onnx_result)
idx = np.argmax(res)
return int(idx), res[idx]
@app.route('/mytest', methods=["POST"]) # 假设请求方法为post请求,路由名写为mytest
def work():
request = flask.request
# 获取请求IP地址
if request.headers.getlist("X-Forwarded-For"):
ip = request.headers.getlist("X-Forwarded-For")[-1]
else:
ip = request.remote_addr
print(">>>>>>>>ip:{}<<<<<<<<<<".format(ip))
returnData = {}
params = request.json # 接收json类型参数
img = params["img"]
img = base64.b64decode(img)
img = np.frombuffer(img, np.uint8)
img = cv2.imdecode(img, cv2.IMREAD_COLOR)
if len(img.shape) == 3:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
result, prob = recognition(img)
returnData["result"] = result
returnData["prob"] = prob
return returnData # 返回识别结果
if __name__ == "__main__":
# recognition函数里面的东西提出来避免每次调用重复加载
onnx_model_path = "./test.onnx"
ort_session = ort.InferenceSession(onnx_model_path)
onnx_input_name = ort_session.get_inputs()[0].name
onnx_outputs_names = ort_session.get_outputs()[0].name
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--host",
default="127.0.0.1",
required=False,
help="host")
parser.add_argument("--port", default=8091, required=False, help="port")
args = parser.parse_args()
# 127.0.0.1本机调用,端口为8091
app.run(host=args.host,
port=int(args.port),
debug=False,
use_reloader=False)
执行myweb.py代码,开放一个8091端口供调用
2、客户端代码
代码如下:
# post.py
import base64
import requests
# 加载一张图像并转为base64
with open('./1.png', 'rb+') as f:
data_base64 = base64.b64encode(f.read())
data_base64 = data_base64.decode()
input_args = {"img": data_base64} # 传入base64图像
res = requests.post('http://127.0.0.1:8091/mytest', json=input_args) # 请求接口获取结果
print(res.text)
执行post.py得到服务端的返回结果
三、总结
本文简单的demo了一下以服务的形式做图像识别,recognition方法也可以替换成其他需要处理的内容,当然上线使用这点还不够,日志记录异常处理这些都没加,下次写一下flask+gunicorn实现高并发,并使用supervisor来监听服务状态。