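Building an Efficient Inference Service with FastAPI and PyTorch
In modern applications, an efficient inference service is essential for handling large volumes of data and real-time requests. This article shows how to combine PyTorch and FastAPI to build a fast, scalable server-side inference service that receives and processes inference requests.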
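1. What is a PyTorch inference service?
PyTorch is a popular deep learning framework widely used for training neural network models. In many real-world applications, a trained model has to be deployed to a production environment so that it can run inference. A PyTorch inference service is a trained model deployed on a server, where it accepts input data and returns predictions.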
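To make the deployment step concrete, here is a minimal sketch of the save, load, and predict cycle. The architecture in `build_model` and the file name `model_state.pt` are assumptions made up for illustration; they do not come from this article, and the weights are untrained.

```python
import os
import tempfile

import torch


def build_model() -> torch.nn.Module:
    # Hypothetical architecture standing in for a real trained network.
    return torch.nn.Sequential(
        torch.nn.Linear(4, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 2),
    )


# "Training side": save the weights (untrained here, for illustration only).
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model_state.pt")
trained = build_model()
torch.save(trained.state_dict(), checkpoint_path)

# "Serving side": rebuild the architecture, load the weights, switch to eval mode.
model = build_model()
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

# Run one forward pass without gradient tracking.
batch = torch.rand(1, 4)
with torch.inference_mode():
    prediction = model(batch)

print(prediction.shape)
```

Calling `model.eval()` and running under `torch.inference_mode()` are the two habits worth carrying into a real service: the first puts layers such as dropout into their inference behavior, the second skips gradient bookkeeping.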
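2. Introduction to FastAPI
FastAPI is a modern Python web framework designed for high performance and ease of use. It is built on Python type hints, generates API documentation automatically, and supports asynchronous request handling, which makes it a good fit for building efficient web services.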
3. Combining PyTorch and FastAPI
To build a server-side PyTorch inference service that receives inference requests through FastAPI, follow these steps:
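- Load the PyTorch model: first, load the trained PyTorch model that will run inference.
- Create the FastAPI application: use FastAPI to create a web application and define its routes and request handlers.
- Write the inference function: write a function that takes the input data, passes it to the PyTorch model, and returns the prediction.
- Define the API route: define a POST route in the FastAPI application that receives inference requests and calls the inference function.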
4. Example code
# Params: image file, ONNX model
import datetime
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
def show_mask(mask, ax):
    # Overlay a semi-transparent blue mask on the given matplotlib axes.
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # Draw positive points (label 1) in green and negative points (label 0) in red.
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    # Draw a box given as (x0, y0, x1, y1).
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))

checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
onnx_model_path = 'interactive_module_quantized_592547_2023_03_19_sam6_long_uncertain.onnx'
ort_session = onnxruntime.InferenceSession(onnx_model_path)
current_onnx_name = onnx_model_path
predictor = SamPredictor(sam)
import re
import requests
def upload(file_path):
    # Omitted in the original article.
    ...

# Now SAM is ready and should start an HTTP server to fetch images.
import time
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from pathlib import Path
app = FastAPI()
class EmbeddingRequest(BaseModel):
    onnx_name: str
    img_url: str

@app.post("/embedding")
async def get_embedding(request: EmbeddingRequest):
    global current_onnx_name, ort_session, predictor

    # Download the image. requests is synchronous, so this call blocks the event loop while it runs.
    img_url = request.img_url
    try:
        response = requests.get(img_url, timeout=30)
        response.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Error downloading image: {e}")

    # Save the image in the working directory under the last path segment of the URL.
    img_name = os.path.basename(img_url)
    img_path = Path(img_name)
    with open(img_path, 'wb') as img_file:
        img_file.write(response.content)

    # Load the image with cv2; imread returns None when the file cannot be decoded.
    image = cv2.imread(img_name)
    if image is None:
        raise HTTPException(status_code=400, detail="Downloaded file is not a readable image")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if current_onnx_name != request.onnx_name:
        # A different ONNX model was requested, so reload the ONNX Runtime session.
        onnx_model_path = request.onnx_name
        ort_session = onnxruntime.InferenceSession(onnx_model_path)
        current_onnx_name = request.onnx_name
        predictor = SamPredictor(sam)

    predictor.set_image(image)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # Save the embedding as a .npy file and upload it.
    timestamp = int(time.time())
    filename = f'embedding_{timestamp}.npy'
    np.save(filename, image_embedding)
    res = upload(filename)
    if res != 500:
        return {"msg": res}
    else:
        return {"msg": "failed"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=4536)
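The handler in the listing reloads the ONNX Runtime session whenever the requested model name changes. One possible alternative, sketched here under assumptions, is a small cache keyed by model name. `ModelCache` and `fake_loader` are hypothetical names introduced for this example; in the service, the loader could be `onnxruntime.InferenceSession`. Unlike the listing, this sketch keeps every loaded model in memory rather than only the most recent one.

```python
import threading
from typing import Callable, Dict, Generic, List, TypeVar

T = TypeVar("T")


class ModelCache(Generic[T]):
    """Load each model at most once per name and reuse it across requests."""

    def __init__(self, loader: Callable[[str], T]) -> None:
        self._loader = loader
        self._models: Dict[str, T] = {}
        self._lock = threading.Lock()

    def get(self, name: str) -> T:
        # The lock keeps two concurrent requests from loading the same model twice.
        with self._lock:
            if name not in self._models:
                self._models[name] = self._loader(name)
            return self._models[name]


# Hypothetical loader used only for demonstration; it records each load call.
load_calls: List[str] = []


def fake_loader(name: str) -> str:
    load_calls.append(name)
    return f"session for {name}"


cache = ModelCache(fake_loader)
first = cache.get("a.onnx")
second = cache.get("a.onnx")  # served from the cache, no second load
third = cache.get("b.onnx")
print(load_calls)
```

Whether caching several sessions is worthwhile depends on how much memory each model needs; for large models, keeping only one session at a time, as the listing does, may be the safer choice.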