Embedding计算服务设计
基本设计
Embedding 的计算过程使用Python进行开发,提供HTTP接口供其他服务调用,主要包括以下功能:
- Embedding 层的计算:接收用户输入的特征,计算对应的Embedding向量。
- Embedding 层的更新:接收用户输入的特征和Embedding向量,更新Embedding层的参数。
- Embedding 层的查询:接收用户输入的特征,查询对应的Embedding向量。
在实际执行的过程中,使用ONNX Runtime作为深度学习推理引擎,将Embedding层的计算放到GPU/CPU上,提高推理速度。
使用 FastAPI 作为 Web 框架,支持异步请求处理,适合用于构建高性能的 Web 服务。
具体实现
该部分主要依赖了以下库:
- FastAPI:用于构建 Web 服务。
- PyTorch:用于构建深度学习模型。
- ONNX Runtime:用于部署深度学习模型。
- Uvicorn:用于部署 FastAPI 应用。
初始化部分,主要包括加载模型和初始化 ONNX Runtime:
# Initializing
# Module-level readiness flag shared with the request handlers:
# 0 = still initializing, 1 = ready/free, 2 = busy processing a request.
status = 0
print("Initializing now")
def show_mask(mask, ax):
    """Overlay *mask* on the matplotlib axis *ax* as a translucent blue layer.

    The mask's trailing two dimensions are taken as (height, width); nonzero
    pixels are painted dodger-blue at 60% opacity, zero pixels stay fully
    transparent.
    """
    # RGBA color: (30, 144, 255) = dodger blue, alpha 0.6.
    rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast the single-channel mask against the 4-channel color.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on *ax*: green stars for positive labels (1),
    red stars for negative labels (0).

    *coords* is an (N, 2) array of (x, y) points; *labels* is a length-N
    array of 0/1 flags selecting the color of each point.
    """
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    # Shared star-marker styling for both point classes.
    star = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(positives[:, 0], positives[:, 1], color='green', **star)
    ax.scatter(negatives[:, 0], negatives[:, 1], color='red', **star)
def show_box(box, ax):
    """Draw an unfilled green rectangle on *ax* for a box given as
    (x0, y0, x1, y1) corner coordinates."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    # Transparent face so only the 2px green outline is visible.
    rect = plt.Rectangle((left, top), width, height,
                         edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
# Load the SAM vit_h checkpoint and build the predictor.
# NOTE(review): assumes the .pth file sits in the working directory —
# TODO confirm deployment layout.
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
# NOTE: Param - onnx model
# Default quantized ONNX decoder model used when the request does not
# override it.
onnx_model_path = 'interactive_module_quantized_592547_2023_03_19_sam6_long_uncertain.onnx'
ort_session = onnxruntime.InferenceSession(onnx_model_path)
# Track the currently-loaded ONNX model name so that repeated requests for
# the same model do not reload the ORT session.
current_onnx_name = onnx_model_path
predictor = SamPredictor(sam)
# Ready: flip the readiness flag so the /embedding endpoint accepts work.
status = 1
print("Ready now.")
接着定义了一个异步的接口,用于计算Embedding向量:
@app.post("/embedding")
async def get_embedding(request: EmbeddingRequest):
    """Download an image, compute its SAM embedding, save it as a .npy file,
    and upload the file.

    Returns ``{"msg": <upload result>}`` on success or ``{"msg": "failed"}``
    when the upload reports 500.  Responds 400 when the service is busy or
    still initializing, or when the image cannot be downloaded/decoded.
    """
    global status, current_onnx_name, predictor, ort_session
    if status == 0 or status == 2:
        print("Busy or not ready now")
        raise HTTPException(status_code=400, detail="Busy now")
    # Busy
    status = 2
    print("Working now")
    try:
        img_url = request.img_url
        try:
            # Timeout so a stalled download cannot hold the busy flag forever.
            response = requests.get(img_url, timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise HTTPException(status_code=400, detail=f"Error downloading image: {e}")
        img_name = os.path.basename(img_url)
        img_path = Path(img_name)
        with open(img_path, 'wb') as img_file:
            img_file.write(response.content)
        # Now use cv2 to load the image.
        image = cv2.imread(img_name)
        if image is None:
            # The downloaded file was not a decodable image.
            raise HTTPException(status_code=400, detail="Downloaded file is not a valid image")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if current_onnx_name != request.onnx_name:
            # Reload the ONNX Runtime session for the newly requested model.
            # BUG FIX: the session must be stored in the module-level variable
            # (previously a local assignment silently discarded the reload).
            ort_session = onnxruntime.InferenceSession(request.onnx_name)
            current_onnx_name = request.onnx_name
            predictor = SamPredictor(sam)
        predictor.set_image(image)
        image_embedding = predictor.get_image_embedding().cpu().numpy()
        # Persist the embedding under a timestamped name and upload it.
        timestamp = int(time.time())
        filename = f'embedding_{timestamp}.npy'
        np.save(filename, image_embedding)
        res = upload(filename)
    finally:
        # BUG FIX: always release the busy flag — previously any exception
        # (e.g. a failed download) left status == 2 and the service rejected
        # all further requests forever.
        status = 1
        print("Free")
    if res != 500:
        return {"msg": res}
    # BUG FIX: `failed` was an undefined name (NameError); return the string.
    return {"msg": "failed"}
最后启动服务:
# Script entry point: serve the FastAPI app with uvicorn.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces so the service is reachable from other hosts.
    uvicorn.run(app, host="0.0.0.0", port=4536)