Sapphire项目日志(二十二)

Embedding计算服务设计

基本设计

Embedding 的计算过程使用Python进行开发,提供HTTP接口供其他服务调用,主要包括以下功能:

  • Embedding 层的计算:接收用户输入的特征,计算对应的Embedding向量。
  • Embedding 层的更新:接收用户输入的特征和Embedding向量,更新Embedding层的参数。
  • Embedding 层的查询:接收用户输入的特征,查询对应的Embedding向量。

在实际执行的过程中,使用ONNX Runtime作为深度学习推理引擎,将Embedding层的计算放到GPU/CPU上,提高推理速度。

使用 FastAPI 作为 Web 框架,支持异步请求处理,适合用于构建高性能的 Web 服务。

具体实现

该部分主要依赖了以下库:

  • FastAPI:用于构建 Web 服务。
  • PyTorch:用于构建深度学习模型。
  • ONNX Runtime:用于部署深度学习模型。
  • Uvicorn:用于部署 FastAPI 应用。

初始化部分,主要包括加载模型和初始化 ONNX Runtime:

# Initializing
status = 0
print("Initializing now")

def show_mask(mask, ax):
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))   

checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)

# NOTE: Param - onnx model
# I think I'll use this model as default
onnx_model_path = 'interactive_module_quantized_592547_2023_03_19_sam6_long_uncertain.onnx'
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# If the user use the same onnx model, no need to load ort again.
current_onnx_name = onnx_model_path

predictor = SamPredictor(sam)

# Ready
status = 1
print("Ready now.")

接着定义了一个异步的接口,用于计算Embedding向量:

@app.post("/embedding")
async def get_embedding(request: EmbeddingRequest):
    global status
    if status==0 or status==2:
        print("Busy or not ready now")
        raise HTTPException(status_code=400, detail=f"Busy now")
    # Busy
    status=2
    print("Working now")
    
    img_url = request.img_url
    try:
        response = requests.get(img_url)
        response.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Error downloading image: {e}")
    img_name = os.path.basename(img_url)
    img_path = Path(img_name)
    with open(img_path, 'wb') as img_file:
        img_file.write(response.content)

    # Now use cv2 to load the image
    image = cv2.imread(img_name)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    global current_onnx_name
    global predictor

    if current_onnx_name != request.onnx_name:
        # Then you have to reload ort
        onnx_model_path = request.onnx_name
        ort_session = onnxruntime.InferenceSession(onnx_model_path)
        current_onnx_name = request.onnx_name
        predictor = SamPredictor(sam)

    predictor.set_image(image)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    # So now save the numpy tensor
    timestamp = int(time.time())
    filename = f'embedding_{timestamp}.npy'
    np.save(filename, image_embedding)
    res = upload(filename)
    # Free
    status=1
    print("Free")
    if res != 500:
        return {"msg": res}
    else:
        return {"msg": failed}

最后启动服务:

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=4536)
  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值