object_detection_webapi
Build an object detection web API server with Flask and TensorFlow.
This article records the knowledge and caveats I picked up while writing the program.
TensorFlow
TensorFlow Models
The TensorFlow Models package contains a number of models implemented with TensorFlow, divided into four parts:
- official models
- research models
- samples folder
- tutorials folder
To install, clone the repository with git and place the whole models folder under the TensorFlow directory. The TensorFlow directory can be obtained via `import tensorflow as tf; import os; os.path.split(tf.__file__)[0]`.
The object detection code we need lives in the research models.
Tensorflow Object Detection
The Tensorflow Object Detection API is an object detection package maintained by researchers; it localizes and identifies multiple objects in a single image.
Tensorflow Object Detection API installation steps
The Tensorflow detection model zoo contains out-of-the-box models with different speed/accuracy trade-offs, usually trained on well-known public datasets.
This program is based on, but goes beyond, the "Quick Start: Jupyter notebook for off-the-shelf inference".
TensorFlow model files
checkpoint(*.ckpt)
Usually depends on the code that created the model; it stores the variables saved during training, such as the weights.
The model's computation graph can also be restored from a GraphDef (*.pb) file.
A frozen graph is a GraphDef file saved after the model's variables have been converted to constants, typically with the convert_variables_to_constants function.
Alternatively, freeze_graph.py can be used to combine a checkpoint file and a GraphDef file into a FrozenGraphDef file.
SavedModel
A directory combining a GraphDef and a checkpoint, plus a SignatureDef that marks the model's input and output parameters.
WEB API
HTTP
Hyper Text Transfer Protocol: a stateless protocol built on top of TCP. The basic workflow is that the client sends an HTTP request stating the resource it wants to access and the requested action; the server receives the request, processes it, performs the corresponding action on its resources, and finally returns the result to the client in an HTTP response.
HTTP is a text-based protocol.
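Because the protocol is plain text, a request is just lines of text: a request line, headers, a blank line, and the body. A minimal sketch (the Host header and the body are hypothetical placeholders; the URL matches the API defined in this article):

```python
# HTTP is plain text: a POST request is nothing more than lines of text.
# The Host header and the body below are placeholder examples.
body = "image=...base64 data..."
request_text = (
    "POST /api/object_detect/predict HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "Content-Type: application/x-www-form-urlencoded\r\n"
    + "Content-Length: " + str(len(body)) + "\r\n"
    "\r\n"
    + body
)
print(request_text.splitlines()[0])  # POST /api/object_detect/predict HTTP/1.1
```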
Binary data transferred over HTTP is therefore usually base64-encoded.
Further reading: the Base64 algorithm, Base64Encode, and UrlEncode encodings and their applications
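As a small sketch with Python's standard base64 module, the encode/decode round trip looks like this (the four input bytes are an arbitrary stand-in for image data; they happen to be the first bytes of a JPEG file):

```python
import base64

# Arbitrary binary data standing in for image bytes (the first bytes of a JPEG).
binary_data = bytes([0xFF, 0xD8, 0xFF, 0xE0])

# Encode: bytes -> base64 -> an ASCII string that is safe inside a text protocol.
b64_str = base64.b64encode(binary_data).decode('ascii')
print(b64_str)  # /9j/4A==

# Decode: the server reverses the process to recover the original bytes.
assert base64.b64decode(b64_str) == binary_data
```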
WEB API
Further reading: a comparison of SOAP web services and RESTful web services
REST is a lightweight Web Service architectural style; its implementation and operation are simpler than SOAP and XML-RPC, and it can be realized entirely over plain HTTP.
RESTful API is currently the most popular API design convention, used for designing Web data interfaces.
My personal understanding of a WEB API: the client issues requests to the server over HTTP using its various methods (GET, POST, etc.), and after processing, the server usually returns a JSON payload.
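That request/response exchange can be sketched with only the standard library. The field names follow this article's API; the base64 value and the server's reply are placeholders:

```python
import json
from urllib.parse import urlencode

# Client side: encode the request body as an ordinary HTML form payload.
# "image" and "image_visual" follow this article's API; the value is a placeholder.
request_body = urlencode({"image": "L3BsYWNlaG9sZGVyLi4u", "image_visual": "false"})

# Server side: after processing, the handler typically replies with JSON text.
response_text = json.dumps({"success": True, "result_num": 0, "result": []})

# Client side again: parse the JSON payload back into native data structures.
response = json.loads(response_text)
print(response["success"])  # True
```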
Web frameworks
Web servers (nginx, apache, etc.) and web frameworks (flask, django, etc.) are different things.
The web server maintains the connection between server and client. When the web server receives a request, it passes the request and its parameters to the web framework, which is responsible for generating the content and handing the generated content back to the web server. So the web server's job is to accept requests and return responses, while the web framework's job is content generation.
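In Python this hand-off is standardized by WSGI: the server turns the request into an environ dict and calls the framework's application object, which returns the content. A minimal sketch using the standard library's wsgiref (the application below is a stand-in, not Flask itself):

```python
from wsgiref.util import setup_testing_defaults

# The "framework" side: a WSGI application whose only job is content generation.
def application(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'Hello from the framework']

# The "server" side: it parses the client's request into an environ dict,
# calls the framework, and would send the returned bytes back to the client.
environ = {}
setup_testing_defaults(environ)  # fills in a minimal fake HTTP request

status_holder = []
def start_response(status, headers):
    status_holder.append(status)

body = b''.join(application(environ, start_response))
print(status_holder[0], body)  # 200 OK b'Hello from the framework'
```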
Flask
Flask uses Python decorators to set up URL routing.
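A minimal sketch of decorator-based routing (the /ping route is a made-up example; Flask's test client stands in for a real web server):

```python
from flask import Flask

app = Flask(__name__)

# The decorator registers a URL rule: requests matching "/ping" are routed
# to the decorated function, and its return value becomes the response body.
@app.route("/ping")
def ping():
    return "pong"

# Flask's built-in test client exercises the route without starting a server.
with app.test_client() as client:
    resp = client.get("/ping")
print(resp.data)  # b'pong'
```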
Processing flow of the Flask object detection handler
After Flask receives a request, it extracts the base64-encoded image data from the request and converts it to a numpy.ndarray (conversion chain: base64 string → binary data → BytesIO → PIL.Image → numpy.ndarray).
The array is then fed into the previously loaded model for inference. The raw result (in matrix form) is converted into a structured result, and, depending on the request, a visualization may additionally be produced (with the visualize_boxes_and_labels_on_image_array function from the object detection package) and converted back (numpy.ndarray → PIL.Image → (optionally saved as JPEG) BytesIO → binary data → base64 string).
Flask serializes the structured result to JSON and returns it to the client.
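A sketch of that last step, assuming a structured result shaped like the one this server produces: serialization is a single json.dumps call (Flask's jsonify does this and additionally sets the response's Content-Type):

```python
import json
import uuid

# A hypothetical structured result for one detected object.
objects = [{"name": "dog",
            "bndbox": {"xmin": 19, "ymin": 24, "xmax": 323, "ymax": 554},
            "score": 0.94}]

res_dict = {"success": True,
            "result": objects,
            "result_num": len(objects),
            "log_id": str(uuid.uuid4())}

payload = json.dumps(res_dict)            # what jsonify would send back
print(json.loads(payload)["result_num"])  # 1
```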
Summary
The implementation took quite a bit of time; I learned a lot and touched many new areas, yet when I finally sat down to summarize, I could hardly tell what exactly I had newly learned, which is rather awkward.
One reason it took so long is my habit of first surveying the whole picture to eliminate the unknown unknowns before drilling into the parts, which led me to read plenty of material I didn't need right away.
Appendix
The code in the appendix is the content of the Python file I exported directly from the Jupyter notebook.
server
#!/usr/bin/env python
# coding: utf-8
# Object Detection WEB API Server
# ==
# Welcome to the Object Detection WEB API Server.
#
# This file starts the web API server. Using an official pre-trained model, it accepts images uploaded by clients via POST, detects the objects in them, and returns the predictions in JSON format.
#
# - Make sure to pull [TensorFlow Models](https://github.com/tensorflow/models) and place it under the TensorFlow directory (probably only the object_detection part is needed rather than the full repository). The TensorFlow directory can be obtained via ```import tensorflow as tf; import os; os.path.split(tf.__file__)[0]```
#
# - Install the Tensorflow Object Detection API following the [Tensorflow Object Detection API installation steps](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md)
#
# - Install flask
# ```pip install flask```
#
# - Install PIL
# ```pip install Pillow```
#
# # Imports
# In[1]:
import numpy as np
import tensorflow as tf
import io, os, sys
from distutils.version import StrictVersion
from collections import defaultdict
from PIL import Image
# if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
#     raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')
# In[2]:
TF_PATH = os.path.split(tf.__file__)[0]
TF_MR_PATH = os.path.join(TF_PATH, 'models', 'research')
sys.path.append(TF_MR_PATH)
TF_SLIM_PATH = os.path.join(TF_MR_PATH, 'slim')
sys.path.append(TF_SLIM_PATH)
TF_OD_PATH = os.path.join(TF_MR_PATH, 'object_detection')
# TF_MR_PATH is already on sys.path, so import object_detection directly.
from object_detection.utils import label_map_util
from object_detection.utils import ops as utils_ops
from object_detection.utils import visualization_utils as vis_util
# In[3]:
from flask import Flask, request, jsonify
import base64
import uuid
# # Variables that may need to be changed
# In[4]:
# List of the strings that is used to add correct label for each box.
LABELS_PATH = ""
if LABELS_PATH == "":
    LABELS_PATH = os.path.join(TF_OD_PATH, 'data', 'mscoco_label_map.pbtxt')
# If using a model from the model zoo, leave FROZEN_GRAPH_PATH == ""; otherwise change it.
FROZEN_GRAPH_PATH = ""
# # Model preparation
# ## Variables
#
# Any model exported with the `export_inference_graph.py` tool can be loaded by pointing the `FROZEN_GRAPH_PATH` variable at the new .pb file (model).
#
# By default we use the "SSD with Mobilenet" model. See the [detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) for other out-of-the-box models with different speed/accuracy trade-offs.
# In[5]:
if FROZEN_GRAPH_PATH == "":
    # What model to download.
    MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
    MODEL_FILENAME = MODEL_NAME + '.tar.gz'
    MODEL_PATH = os.path.join(TF_OD_PATH, MODEL_FILENAME)
    DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    FROZEN_GRAPH_PATH = os.path.join(TF_OD_PATH, MODEL_NAME, 'frozen_inference_graph.pb')
    if not os.path.exists(FROZEN_GRAPH_PATH):
        if not os.path.exists(MODEL_PATH):
            import six.moves.urllib as urllib
            opener = urllib.request.URLopener()
            opener.retrieve(DOWNLOAD_BASE + MODEL_FILENAME, MODEL_PATH)
        import tarfile
        tar_file = tarfile.open(MODEL_PATH)
        for file in tar_file.getmembers():
            file_name = os.path.basename(file.name)
            if 'frozen_inference_graph.pb' in file_name:
                tar_file.extract(file, TF_OD_PATH)
# ## Load a (frozen) Tensorflow model into memory.
# In[6]:
def load_model(PATH_TO_FROZEN_GRAPH, graph=None):
    if graph is None:
        graph = tf.Graph()
    with graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return graph
# ## Loading label map
# Label maps map indices to category names, so that when our convolution network predicts `5`, we know that this corresponds to `airplane`. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine
# In[7]:
category_index = label_map_util.create_category_index_from_labelmap(LABELS_PATH, use_display_name=True)
# # Detection function
# Run inference on a single image and return the raw model predictions.
# In[8]:
def run_inference_for_single_image(image, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
                'num_detections', 'detection_boxes', 'detection_scores',
                'detection_classes', 'detection_masks'
            ]:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                        tensor_name)
            if 'detection_masks' in tensor_dict:
                # The following processing is only for single image
                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
                # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
                    detection_masks, detection_boxes, image.shape[0], image.shape[1])
                detection_masks_reframed = tf.cast(
                    tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                # Follow the convention by adding back the batch dimension
                tensor_dict['detection_masks'] = tf.expand_dims(
                    detection_masks_reframed, 0)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
            # Run inference
            output_dict = sess.run(tensor_dict,
                                   feed_dict={image_tensor: np.expand_dims(image, 0)})
            # all outputs are float32 numpy arrays, so convert types as appropriate
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict[
                'detection_classes'][0].astype(np.uint8)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
            if 'detection_masks' in output_dict:
                output_dict['detection_masks'] = output_dict['detection_masks'][0]
    return output_dict
# Convert the raw model predictions into a structured format that is easier to understand and transmit.
# In[ ]:
def convert_to_structure_format(
        boxes,
        classes,
        scores,
        category_index,
        image_shape=None,
        use_normalized_coordinates=True,
        max_boxes=None,
        min_score_thresh=0.5):
    """
    Args:
        boxes: a numpy array of shape [N, 4]
        classes: a numpy array of shape [N]. Note that class indices are 1-based,
            and match the keys in the label map.
        scores: a numpy array of shape [N] or None.
        category_index: a dict containing category dictionaries (each holding
            category index `id` and category name `name`) keyed by category indices.
        image_shape: (height, width) of the image; required when
            use_normalized_coordinates is True.
        use_normalized_coordinates: whether boxes is to be interpreted as
            normalized coordinates or not.
        max_boxes: maximum number of boxes. If None, convert all boxes.
        min_score_thresh: minimum score threshold for a box to be converted.
    Returns:
        a list of dicts, one per detected object, each holding the object's
        name, bounding box, and score.
    """
    if not max_boxes:
        max_boxes = boxes.shape[0]
    # image_shape-->(h, w) in NumPy, image_size-->(w, h) in PIL
    if use_normalized_coordinates:
        im_height, im_width = image_shape
    else:
        im_height, im_width = (1, 1)
    objects = [{"name": category_index[classes[i]]['name'] if classes[i] in category_index else "",
                "bndbox": {"xmin": int(boxes[i][1] * im_width), "ymin": int(boxes[i][0] * im_height),
                           "xmax": int(boxes[i][3] * im_width), "ymax": int(boxes[i][2] * im_height)},
                "score": float(scores[i])}
               for i in range(min(max_boxes, boxes.shape[0]))
               if scores[i] is None or scores[i] > min_score_thresh]
    return objects
# Run inference on an image in ndarray format and return the structured result and, optionally, the visualization.
# In[ ]:
def predict_on_image_np(image_np, graph, visualization=False):
    output_dict = run_inference_for_single_image(image_np, graph)
    result_list = convert_to_structure_format(
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        image_np.shape[-3:-1],
        use_normalized_coordinates=True)
    if visualization:
        # Draws boxes, labels, and masks onto image_np in place.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            output_dict['detection_boxes'],
            output_dict['detection_classes'],
            output_dict['detection_scores'],
            category_index,
            instance_masks=output_dict.get('detection_masks'),
            use_normalized_coordinates=True,
            max_boxes_to_draw=None,
            line_thickness=8)
    else:
        image_np = None
    return result_list, image_np
# # Flask setting
# Set up the Flask server
#
# ## API description
# Detects the position, class, and confidence of the objects in an image.
# ### Request
# - HTTP method: `POST`
# - Request URL: `/api/object_detect/predict`
# - Header:
#
# |Parameter|Value|
# |:---|:--|
# |Content-Type|application/x-www-form-urlencoded|
#
# - Body:
#
# |Parameter|Required|Type|Range|Description|
# |:---|:----|:---|:----|:---|
# |image|true|string|-|image data, base64-encoded|
# |image_visual|false|boolean|-|whether to return the visualization; defaults to false|
#
# ### Response
#
# The result is returned in JSON format.
#
# #### Response parameters
#
# |Parameter|Required|Type|Description|
# |:-|:-|:-|:-|
# |log_id|yes|UUID|unique log id, used to trace problems|
# |result|yes|list|prediction results|
# |+bndbox|yes|dict|box information|
# |++xmax|yes|int|horizontal coordinate of the box's bottom-right corner|
# |++xmin|yes|int|horizontal coordinate of the box's top-left corner|
# |++ymax|yes|int|vertical coordinate of the box's bottom-right corner|
# |++ymin|yes|int|vertical coordinate of the box's top-left corner|
# |+name|yes|string|object class name|
# |+score|yes|float|score, which can be read as the confidence|
# |result_num|yes|int|number of detected objects|
# |success|yes|boolean|whether the prediction succeeded|
# |image_visual|no|string|base64-encoded visualization|
#
# #### Response example
# ```
# {'log_id': 'c4808689-f3f1-4d01-907c-2fa662626f8b',
# 'result': [{'bndbox': {'xmax': 323, 'xmin': 19, 'ymax': 554, 'ymin': 24},
# 'name': 'dog',
# 'score': 0.9406907558441162},
# {'bndbox': {'xmax': 996, 'xmin': 412, 'ymax': 588, 'ymin': 69},
# 'name': 'dog',
# 'score': 0.9345026612281799}],
# 'result_num': 2,
# 'success': True}
# ```
# In[ ]:
app = Flask(__name__)
URL_PRED = "/api/object_detect/predict"

@app.route("/")
def homepage():
    return "Welcome to the object detection REST API!\nPlease use " + URL_PRED

@app.route(URL_PRED, methods=["GET", "POST"])
def predict():
    if request.method == "POST":
        res_dict = {'success': False}
        # Read the "image_visual" form field documented above.
        visualization = request.form.get("image_visual", default='').lower() == 'true'
        if request.form.get("image"):
            image_b64 = request.form.get("image")
            # base64.b64decode: decode the Base64-encoded bytes-like object or ASCII string.
            image_b64decode = base64.b64decode(image_b64)
            imageIO = io.BytesIO(image_b64decode)
            image = Image.open(imageIO)
            # The visualization draws on the array in place, so make a writable copy.
            image_np = np.asarray(image).copy() if visualization else np.asarray(image)
            res_dict['result'], res_image_np = predict_on_image_np(image_np, detection_graph, visualization)
            res_dict['result_num'] = len(res_dict['result'])
            if visualization:
                res_image = Image.fromarray(res_image_np)
                img_buffer = io.BytesIO()
                res_image.save(img_buffer, format='JPEG')
                # base64.b64encode: encode the bytes-like object using Base64 and return a bytes object.
                res_image_b64 = base64.b64encode(img_buffer.getvalue())
                res_dict['image_visual'] = res_image_b64.decode('utf-8')
            res_dict['success'] = True
            res_dict['log_id'] = str(uuid.uuid4())
            return jsonify(res_dict)
        else:
            return "Please POST an image."
    else:
        return "Please use POST method."
# # Start web api server
# In[ ]:
if __name__ == '__main__':
    print("Loading model...")
    detection_graph = load_model(FROZEN_GRAPH_PATH)
    print("Starting web api server...")
    # Set the host to '0.0.0.0' to make the server available externally as well.
    app.run('0.0.0.0')
# In[ ]: