Sapphire项目日志(二十一)

PyTorch推理服务和FastAPI构建使

用 FastAPI 和 PyTorch 构建高效的推理服务

在现代应用程序中,构建高效的推理服务对于处理大规模数据和实时请求至关重要。本文将介绍如何结合 PyTorch 和 FastAPI 来构建一个快速、可扩展的服务端推理服务,以便接收和处理推理请求。

1. 什么是 PyTorch 推理服务?

PyTorch 是一个流行的深度学习框架,广泛用于训练神经网络模型。在许多实际应用中,需要将训练好的模型部署到生产环境中,以便进行推理任务。PyTorch 推理服务即是将训练好的模型部署到服务器上,以便接收输入数据并生成预测结果。

2. FastAPI 简介

FastAPI 是一个现代的 Python Web 框架,旨在提供高性能和易用性。它基于 Python 类型提示功能,能够自动生成 API 文档,并支持异步请求处理,使其成为构建高效 Web 服务的理想选择。

3. 结合 PyTorch 和 FastAPI

要构建一个服务端 PyTorch 推理服务,并使用 FastAPI 来接收推理请求,需要按照以下步骤进行:

  • 加载 PyTorch 模型:首先,加载训练好的 PyTorch 模型,该模型将用于进行推理任务。

  • 创建 FastAPI 应用:使用 FastAPI 创建一个 Web 应用,定义路由和请求处理函数。

  • 编写推理函数:编写一个函数,该函数将接收输入数据,将其传递给 PyTorch 模型进行推理,并返回预测结果。

  • 定义 API 路由:在 FastAPI 应用中定义一个 POST 路由,用于接收推理请求,并调用推理函数进行处理。

4. 示例代码

### Param: image file, onnx model 

import datetime
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

def show_mask(mask, ax):
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))   

checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)

onnx_model_path = 'interactive_module_quantized_592547_2023_03_19_sam6_long_uncertain.onnx'
ort_session = onnxruntime.InferenceSession(onnx_model_path)

current_onnx_name = onnx_model_path

predictor = SamPredictor(sam)

import re
import requests
def upload(file_path):
    # 省略

# Now SAM is ready and should start an HTTP server to fetch images.
import time
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from pathlib import Path

app = FastAPI()

class EmbeddingRequest(BaseModel):
    onnx_name: str
    img_url: str

@app.post("/embedding")
async def get_embedding(request: EmbeddingRequest):
    img_url = request.img_url
    try:
        response = requests.get(img_url)
        response.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Error downloading image: {e}")
    img_name = os.path.basename(img_url)
    img_path = Path(img_name)
    with open(img_path, 'wb') as img_file:
        img_file.write(response.content)

    # Now use cv2 to load the image
    image = cv2.imread(img_name)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    global current_onnx_name
    global predictor

    if current_onnx_name != request.onnx_name:
        # Then you have to reload ort
        onnx_model_path = request.onnx_name
        ort_session = onnxruntime.InferenceSession(onnx_model_path)
        current_onnx_name = request.onnx_name
        predictor = SamPredictor(sam)

    predictor.set_image(image)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # So now save the numpy tensor
    timestamp = int(time.time())
    filename = f'embedding_{timestamp}.npy'
    np.save(filename, image_embedding)
    res = upload(filename)
    if res != 500:
        return {"msg": res}
    else:
        return {"msg": failed}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=4536)
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值