Deploying Wren AI (an open-source text-to-SQL solution) with custom LLMs

About

  • First published: 2024-07-15
  • Wren AI official documentation: https://docs.getwren.ai/overview/introduction
  • Wren AI GitHub repository: https://github.com/Canner/WrenAI

About Wren AI

Wren AI is an open-source text-to-SQL solution: it turns natural-language questions into SQL queries.

Prerequisites

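The services will be started with Docker later, so first make sure Docker (including Docker Compose) is installed and that your network connection works.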
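A quick pre-flight check can confirm that the command-line tools used later in this article are on PATH. This is a minimal sketch; `check_docker_tools` is an illustrative name, and it only checks that the executables exist, not that the Docker daemon is running or that the network works. If only the Docker Compose v2 plugin is installed, the command is `docker compose` (with a space) rather than `docker-compose`.

```python
import shutil
from typing import Dict


def check_docker_tools() -> Dict[str, bool]:
    """Report which Docker command-line tools are available on PATH.

    Only checks that the executables exist; it does not check that the
    Docker daemon is running or that images can be pulled.
    """
    tools = ["docker", "docker-compose"]
    return {tool: shutil.which(tool) is not None for tool in tools}


if __name__ == "__main__":
    for tool, found in check_docker_tools().items():
        print(f"{tool}: {'found' if found else 'NOT found'}")
```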

First, clone the repository:

git clone https://github.com/Canner/WrenAI.git

Using custom LLMs and embedding models in Wren AI

Wren AI currently supports custom LLMs and embedding models. As its official documentation (https://docs.getwren.ai/installation/custom_llm) explains, this requires creating your own provider class.

Wren AI already supports OpenAI-compatible LLMs out of the box. Custom embedding models, however, may fail; the culprit is the following code in wren-ai-service/src/providers/embedder/openai.py:

if self.dimensions is not None:
    response = await self.client.embeddings.create(
        model=self.model, dimensions=self.dimensions, input=text_to_embed
    )
else:
    response = await self.client.embeddings.create(
        model=self.model, input=text_to_embed
    )

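The `if self.dimensions is not None` branch is the one that fails, and it is the branch taken by default: the provider always passes a dimension value (see the appendix, where it falls back to EMBEDDING_MODEL_DIMENSION or 3072), so self.dimensions is never None. A likely cause is that `dimensions` is an OpenAI-specific request parameter, supported only by text-embedding-3 and later models, which many OpenAI-compatible embedding services do not accept. My temporary workaround is therefore to comment this branch out.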
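Rather than commenting the branch out, one alternative is to make sending `dimensions` opt-in. The sketch below is hypothetical: `build_embedding_kwargs` and its `send_dimensions` flag are my own names, not part of Wren AI or the `openai` package. It builds the keyword arguments for `client.embeddings.create(...)` and includes `dimensions` only when explicitly requested.

```python
from typing import Any, Dict, List, Optional, Union


def build_embedding_kwargs(
    model: str,
    text_to_embed: Union[str, List[str]],
    dimensions: Optional[int] = None,
    send_dimensions: bool = False,
) -> Dict[str, Any]:
    """Build keyword arguments for client.embeddings.create(...).

    `dimensions` is only included when the caller explicitly opts in,
    because many OpenAI-compatible embedding endpoints reject it.
    """
    kwargs: Dict[str, Any] = {"model": model, "input": text_to_embed}
    if send_dimensions and dimensions is not None:
        kwargs["dimensions"] = dimensions
    return kwargs
```

Inside `run`, the call would then become `await self.client.embeddings.create(**build_embedding_kwargs(self.model, text_to_embed, self.dimensions))`, and the flag could be wired to an environment variable of your choosing.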

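Concretely, create a file named openai_like.py in the wren-ai-service/src/providers/embedder directory. It defines an OpenAI-like embedding provider registered under the name openai_like_embedder; it is essentially the built-in OpenAI embedder with the `dimensions` branches commented out. The full code is in the appendix.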
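The name passed to `@provider("openai_like_embedder")` in the appendix is the same one that `EMBEDDER_PROVIDER` is set to in `.env.ai`. To illustrate how a decorator-based registry maps such a name to a class, here is a minimal hypothetical sketch; it is NOT Wren AI's `src.providers.loader` implementation, and `get_provider` is an illustrative name. Note that registration happens when the module is imported, which is why the file must live where the service can import it.

```python
from typing import Callable, Dict, Type

# Hypothetical registry illustrating the decorator pattern; this is NOT
# Wren AI's src.providers.loader implementation.
_PROVIDERS: Dict[str, Type] = {}


def provider(name: str) -> Callable[[Type], Type]:
    """Register the decorated class under `name` at import time."""
    def decorator(cls: Type) -> Type:
        _PROVIDERS[name] = cls
        return cls
    return decorator


def get_provider(name: str) -> Type:
    """Look up a registered provider class by name (e.g. from an env var)."""
    try:
        return _PROVIDERS[name]
    except KeyError:
        raise ValueError(
            f"Unknown provider {name!r}; registered: {sorted(_PROVIDERS)}"
        ) from None


@provider("openai_like_embedder")
class OpenAILikeEmbedderProvider:
    """Stand-in for the provider class defined in the appendix."""
```

For example, `get_provider(os.environ["EMBEDDER_PROVIDER"])` would return `OpenAILikeEmbedderProvider` when the variable is set to `openai_like_embedder`.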

Configure the Docker environment variables and start the services

First, change into the docker directory and copy .env.example to .env.local.

Then copy .env.ai.example to .env.ai and edit its LLM and embedding settings. The relevant part looks like this:

LLM_PROVIDER=openai_llm
LLM_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxx
LLM_OPENAI_API_BASE=http://api.siliconflow.cn/v1
GENERATION_MODEL=meta-llama/Meta-Llama-3-70B
# GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 32768, "response_format": {"type": "json_object"}}

EMBEDDER_PROVIDER=openai_like_embedder
EMBEDDING_MODEL=bge-m3
EMBEDDING_MODEL_DIMENSION=1024
EMBEDDER_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxx
EMBEDDER_OPENAI_API_BASE=https://xxxxxxxxxxxxxxxx/v1
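The service reads these variables from the environment, and the provider in the appendix calls `int(...)` on `EMBEDDING_MODEL_DIMENSION`, so a typo in `.env.ai` only surfaces once the container starts. Below is a hypothetical pre-flight check (`parse_env`, `check_env` and `REQUIRED_KEYS` are illustrative names). It handles plain `KEY=VALUE` lines only, not quoted values or `export` prefixes.

```python
from typing import Dict, List

# Keys taken from the .env.ai excerpt above.
REQUIRED_KEYS = [
    "LLM_PROVIDER",
    "LLM_OPENAI_API_KEY",
    "LLM_OPENAI_API_BASE",
    "GENERATION_MODEL",
    "EMBEDDER_PROVIDER",
    "EMBEDDING_MODEL",
    "EMBEDDING_MODEL_DIMENSION",
    "EMBEDDER_OPENAI_API_KEY",
    "EMBEDDER_OPENAI_API_BASE",
]


def parse_env(text: str) -> Dict[str, str]:
    """Parse simple KEY=VALUE lines, skipping blanks and # comments."""
    values: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def check_env(values: Dict[str, str]) -> List[str]:
    """Return a list of problems found in the parsed settings."""
    problems = [f"missing {key}" for key in REQUIRED_KEYS if not values.get(key)]
    dim = values.get("EMBEDDING_MODEL_DIMENSION", "")
    if dim and not dim.isdigit():
        problems.append(f"EMBEDDING_MODEL_DIMENSION is not an integer: {dim!r}")
    return problems
```

For example, run `check_env(parse_env(open(".env.ai", encoding="utf-8").read()))` from the docker directory and fix anything it reports before starting the containers.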

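Because we created a custom embedding provider, the file has to be made available inside the Docker container. One way to do this is to add a volumes entry to the wren-ai-service definition in docker-compose.yaml: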

wren-ai-service:
  image: ghcr.io/canner/wren-ai-service:${WREN_AI_SERVICE_VERSION}
  volumes:
    - /root/WrenAI/wren-ai-service/src:/src
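The host path `/root/WrenAI/wren-ai-service/src` assumes the repository was cloned into `/root`; adjust it to your own checkout. A bind mount hides whatever the image had at the target path, so if `/src` is where the image keeps the service code, your host copy replaces it entirely and should match the image version selected by `WREN_AI_SERVICE_VERSION`. A hypothetical helper (`check_mount_source` is an illustrative name) to check the host side before `up`:

```python
from pathlib import Path
from typing import List


def check_mount_source(src_dir: str) -> List[str]:
    """Return problems with the host directory bind-mounted into wren-ai-service."""
    root = Path(src_dir)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems: List[str] = []
    # The custom provider must be present at the path the appendix describes.
    custom = root / "providers" / "embedder" / "openai_like.py"
    if not custom.is_file():
        problems.append(f"missing {custom}")
    return problems
```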

Finally, start the services:

docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai up -d

To stop the services:

docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai down
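`up` and `down` must be given the same compose files and env files, so it can help to build both argument lists in one place. A minimal hypothetical wrapper (`compose_command` is an illustrative name) that reproduces exactly the two commands above:

```python
from typing import List

COMPOSE_FILES = ["docker-compose.yaml", "docker-compose.llm.yaml"]
ENV_FILES = [".env.local", ".env.ai"]


def compose_command(action: str) -> List[str]:
    """Build the docker-compose argv used in this article for `up` or `down`."""
    if action not in ("up", "down"):
        raise ValueError(f"unsupported action: {action}")
    cmd = ["docker-compose"]
    for compose_file in COMPOSE_FILES:
        cmd += ["-f", compose_file]
    for env_file in ENV_FILES:
        cmd += ["--env-file", env_file]
    cmd.append(action)
    if action == "up":
        cmd.append("-d")  # detached mode, as in the command above
    return cmd
```

For example, from the docker directory, `subprocess.run(compose_command("up"), check=True)` (after `import subprocess`) starts the services.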

Appendix

openai_like.py (the custom embedding provider):

import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import backoff
import openai
from haystack import Document, component
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.utils import Secret
from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm

from src.core.provider import EmbedderProvider
from src.providers.loader import provider

import sys

# Log to stdout. basicConfig already attaches a stdout handler, so no extra
# StreamHandler is added; a second one would print every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

logger = logging.getLogger("wren-ai-service")

EMBEDDER_OPENAI_API_BASE = "https://api.openai.com/v1"
EMBEDDING_MODEL = "text-embedding-3-large"
EMBEDDING_MODEL_DIMENSION = 3072


@component
class AsyncTextEmbedder(OpenAITextEmbedder):
    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
    ):
        super(AsyncTextEmbedder, self).__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
        )
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, text: str):
        if not isinstance(text, str):
            raise TypeError(
                "OpenAITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder."
            )

        logger.debug(f"Running Async OpenAI text embedder with text: {text}")

        text_to_embed = self.prefix + text + self.suffix

        # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
        # replace newlines, which can negatively affect performance.
        text_to_embed = text_to_embed.replace("\n", " ")

        # if self.dimensions is not None:
        #     response = await self.client.embeddings.create(
        #         model=self.model, dimensions=self.dimensions, input=text_to_embed
        #     )
        # else:
        response = await self.client.embeddings.create(
            model=self.model, input=text_to_embed
        )

        meta = {"model": response.model, "usage": dict(response.usage)}

        return {"embedding": response.data[0].embedding, "meta": meta}


@component
class AsyncDocumentEmbedder(OpenAIDocumentEmbedder):
    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        super(AsyncDocumentEmbedder, self).__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
            batch_size,
            progress_bar,
            meta_fields_to_embed,
            embedding_separator,
        )
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    async def _embed_batch(
        self, texts_to_embed: List[str], batch_size: int
    ) -> Tuple[List[List[float]], Dict[str, Any]]:
        all_embeddings = []
        meta: Dict[str, Any] = {}
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size),
            disable=not self.progress_bar,
            desc="Calculating embeddings",
        ):
            batch = texts_to_embed[i : i + batch_size]
            # if self.dimensions is not None:
            #     response = await self.client.embeddings.create(
            #         model=self.model, dimensions=self.dimensions, input=batch
            #     )
            # else:
            response = await self.client.embeddings.create(
                model=self.model, input=batch
            )
            embeddings = [el.embedding for el in response.data]
            all_embeddings.extend(embeddings)

            if "model" not in meta:
                meta["model"] = response.model
            if "usage" not in meta:
                meta["usage"] = dict(response.usage)
            else:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, documents: List[Document]):
        if (
            not isinstance(documents, list)
            or documents
            and not isinstance(documents[0], Document)
        ):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        logger.debug(
            f"Running Async OpenAI document embedder with documents: {documents}"
        )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = await self._embed_batch(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}


@provider("openai_like_embedder")
class OpenAIEmbedderProvider(EmbedderProvider):
    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        api_base: str = os.getenv("EMBEDDER_OPENAI_API_BASE")
        or EMBEDDER_OPENAI_API_BASE,
        embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
        embedding_model_dim: int = (
            int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
            if os.getenv("EMBEDDING_MODEL_DIMENSION")
            else 0
        )
        or EMBEDDING_MODEL_DIMENSION,
    ):
        def _verify_api_key(api_key: str, api_base: str) -> None:
            """
            this is a temporary solution to verify that the required environment variables are set
            """
            OpenAI(api_key=api_key, base_url=api_base).models.list()

        logger.info(f"Initializing OpenAIEmbedder provider with API base: {api_base}")
        # TODO: currently only OpenAI api key can be verified
        if api_base == EMBEDDER_OPENAI_API_BASE:
            _verify_api_key(api_key.resolve_value(), api_base)
            logger.info(f"Using OpenAI Embedding Model: {embedding_model}")
        else:
            logger.info(
                f"Using OpenAI API-compatible Embedding Model: {embedding_model}"
            )
        self._api_key = api_key
        self._api_base = api_base
        self._embedding_model = embedding_model
        self._embedding_model_dim = embedding_model_dim

    def get_text_embedder(self):
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )

    def get_document_embedder(self):
        return AsyncDocumentEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )
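Because the code above no longer sends `dimensions`, the service returns vectors at the model's native size (1024 for bge-m3 in the example configuration). `EMBEDDING_MODEL_DIMENSION` may still be relied on elsewhere in Wren AI, for example to size the vector index; if so, it must match that native size. The sketch below mirrors the batching and usage aggregation of `AsyncDocumentEmbedder._embed_batch` and adds a dimension check. It runs offline against a hypothetical fake client (`FakeEmbeddings` and `embed_in_batches` are illustrative names, not Wren AI or openai APIs).

```python
import asyncio
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple


class FakeEmbeddings:
    """Offline stand-in for AsyncOpenAI().embeddings (hypothetical, for illustration)."""

    def __init__(self, dim: int):
        self.dim = dim

    async def create(self, model: str, input: List[str]) -> SimpleNamespace:
        data = [SimpleNamespace(embedding=[float(len(t))] * self.dim) for t in input]
        tokens = sum(len(t.split()) for t in input)  # crude word-count "tokens"
        usage = SimpleNamespace(prompt_tokens=tokens, total_tokens=tokens)
        return SimpleNamespace(model=model, data=data, usage=usage)


async def embed_in_batches(
    embeddings_api: Any,
    model: str,
    texts: List[str],
    batch_size: int,
    expected_dim: int,
) -> Tuple[List[List[float]], Dict[str, Any]]:
    """Embed texts batch by batch, summing usage and checking vector size."""
    all_embeddings: List[List[float]] = []
    meta: Dict[str, Any] = {}
    for i in range(0, len(texts), batch_size):
        response = await embeddings_api.create(model=model, input=texts[i : i + batch_size])
        for item in response.data:
            if len(item.embedding) != expected_dim:
                raise ValueError(
                    f"expected {expected_dim}-dimensional embeddings, got {len(item.embedding)}"
                )
            all_embeddings.append(item.embedding)
        meta.setdefault("model", response.model)
        if "usage" not in meta:
            meta["usage"] = {
                "prompt_tokens": response.usage.prompt_tokens,
                "total_tokens": response.usage.total_tokens,
            }
        else:
            meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
            meta["usage"]["total_tokens"] += response.usage.total_tokens
    return all_embeddings, meta


if __name__ == "__main__":
    texts = ["select all orders", "count customers by city", "top products"]
    embs, meta = asyncio.run(
        embed_in_batches(FakeEmbeddings(1024), "bge-m3", texts, batch_size=2, expected_dim=1024)
    )
    print(len(embs), meta)
```

Against a real endpoint, the fake would be replaced by `AsyncOpenAI(...).embeddings`; a `ValueError` there points to a mismatch between `EMBEDDING_MODEL_DIMENSION` and the model.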