vanna学习日志(五)


前言

vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析,该部分内容为接入各类ai方法。

一、vanna源码分析

pinecone_vector

PineconeDB_VectorStore 类用于使用 Pinecone 向量数据库存储和检索向量数据。这是通过与 Pinecone 客户端交互来实现的。

我们将 PineconeDB_VectorStore 代码分段进行详细注释和解析。

导入和类定义部分

import json  # 导入 JSON 模块,用于处理 JSON 数据
from typing import List  # 导入 List 类型提示

from pinecone import Pinecone, PodSpec, ServerlessSpec  # 从 pinecone 库导入 Pinecone 客户端和规格类
import pandas as pd  # 导入 pandas 模块,用于数据处理
from ..base import VannaBase  # 从本地模块导入 VannaBase 基础类
from ..utils import deterministic_uuid  # 从本地模块导入生成确定性 UUID 的函数

from fastembed import TextEmbedding  # 从 fastembed 模块导入 TextEmbedding 类
  • 导入了所需的模块和库,包括 JSON 处理、类型提示、Pinecone 客户端和规格、数据处理工具 pandas,以及一些本地模块。
class PineconeDB_VectorStore(VannaBase):
    """
    使用 PineconeDB 的向量存储类

    Args:
        config (dict): 配置字典。必须提供 Pinecone 客户端或 API 密钥。
    Raises:
        ValueError: 如果配置无效,抛出错误。
    """
  • 定义 PineconeDB_VectorStore 类,继承自 VannaBase,并在类文档中说明需要传递的配置参数和可能的异常。

初始化方法

    def __init__(self, config=None):
        # 调用父类 VannaBase 的构造函数
        VannaBase.__init__(self, config=config)
        if config is None:
            raise ValueError("需要提供配置,传递 Pinecone 客户端或 API 密钥。")
        
        client = config.get("client")  # 获取配置中的客户端
        api_key = config.get("api_key")  # 获取配置中的 API 密钥
        
        if not api_key and not client:
            raise ValueError("需要在配置中提供 api_key 或传递已配置的客户端")
        
        if not client and api_key:
            self._client = Pinecone(api_key=api_key)  # 使用 API 密钥初始化 Pinecone 客户端
        elif not isinstance(client, Pinecone):
            raise ValueError("client 必须是 Pinecone 的实例")
        else:
            self._client = client  # 使用传递的 Pinecone 客户端

        # 设置其他配置参数
        self.n_results = config.get("n_results", 10)
        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.documentation_namespace = config.get("documentation_namespace", "documentation")
        self.distance_metric = config.get("distance_metric", "cosine")
        self.ddl_namespace = config.get("ddl_namespace", "ddl")
        self.sql_namespace = config.get("sql_namespace", "sql")
        self.index_name = config.get("index_name", "vanna-index")
        self.metadata_config = config.get("metadata_config", {})
        self.server_type = config.get("server_type", "serverless")
        
        if self.server_type not in ["serverless", "pod"]:
            raise ValueError("server_type 必须是 'serverless' 或 'pod'")

        # 配置 Pod 规格
        self.podspec = config.get(
            "podspec",
            PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config),
        )
        # 配置无服务器规格
        self.serverless_spec = config.get("serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2"))
        self._setup_index()  # 设置索引
  • 初始化方法中,调用了父类的构造函数。
  • 检查配置是否为空,如果为空,则抛出 ValueError
  • 根据配置初始化 Pinecone 客户端,确保传递的客户端是 Pinecone 的实例。
  • 设置其他相关配置参数,如返回结果数、向量维度、模型名称、命名空间、距离度量等。
  • 检查 server_type 是否为有效值(serverlesspod)。
  • 配置 Pod 和无服务器规格,并调用 _setup_index 方法进行索引设置。

索引设置和检查方法

    def _set_index_host(self, host: str) -> None:
        self.Index = self._client.Index(host=host)  # 设置索引主机

    def _setup_index(self) -> None:
        existing_indexes = self._get_indexes()  # 获取现有索引
        
        if self.index_name not in existing_indexes and self.server_type == "serverless":
            self._client.create_index(
                name=self.index_name,
                dimension=self.dimensions,
                metric=self.distance_metric,
                spec=self.serverless_spec,
            )
            pinecone_index_host = self._client.describe_index(self.index_name)["host"]
            self._set_index_host(pinecone_index_host)
        elif self.index_name not in existing_indexes and self.server_type == "pod":
            self._client.create_index(
                name=self.index_name,
                dimension=self.dimensions,
                metric=self.distance_metric,
                spec=self.podspec,
            )
            pinecone_index_host = self._client.describe_index(self.index_name)["host"]
            self._set_index_host(pinecone_index_host)
        else:
            pinecone_index_host = self._client.describe_index(self.index_name)["host"]
            self._set_index_host(pinecone_index_host)

    def _get_indexes(self) -> list:
        return [index["name"] for index in self._client.list_indexes()]  # 列出所有索引的名称

    def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
        fetch_response = self.Index.fetch(ids=[id], namespace=namespace)  # 从指定命名空间获取嵌入
        return fetch_response["vectors"] != {}  # 如果嵌入存在,返回 True
  • _set_index_host 方法设置索引主机。
  • _setup_index 方法根据配置创建或获取索引,如果索引不存在则创建索引,并设置索引主机。
  • _get_indexes 方法列出所有现有索引的名称。
  • _check_if_embedding_exists 方法检查嵌入是否已经存在于指定的命名空间中。

添加和查询数据的方法

    def add_ddl(self, ddl: str, **kwargs) -> str:
        id = deterministic_uuid(ddl) + "-ddl"  # 生成唯一标识符
        if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace):
            print(f"DDL with id: {id} 已存在于索引中,跳过...")
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})],
            namespace=self.ddl_namespace,
        )
        return id

    def add_documentation(self, doc: str, **kwargs) -> str:
        id = deterministic_uuid(doc) + "-doc"  # 生成唯一标识符
        if self._check_if_embedding_exists(id=id, namespace=self.documentation_namespace):
            print(f"Documentation with id: {id} 已存在于索引中,跳过...")
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(doc), {"documentation": doc})],
            namespace=self.documentation_namespace,
        )
        return id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
        id = deterministic_uuid(question_sql_json) + "-sql"  # 生成唯一标识符
        if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace):
            print(f"Question-SQL with id: {id} 已存在于索引中,跳过...")
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(question_sql_json), {"sql": question_sql_json})],
            namespace=self.sql_namespace,
        )
        return id
  • add_ddl 方法添加 DDL 数据到指定命名空间,如果数据已经存在则跳过。
  • add_documentation 方法添加文档数据到指定命名空间,如果数据已经存在则跳过。
  • add_question_sql 方法添加问题和 SQL 数据到指定命名空间,如果数据已经存在则跳过。

查询相关数据的方法

	def get_related_ddl(self, question: str, **kwargs) -> list:
	    res = self.Index.query(
	        namespace=self.ddl_namespace,  # 使用DDL命名空间
	        vector=self.generate_embedding(question),  # 生成问题的嵌入向量
	        top_k=self.n_results,  # 返回的结果数量
	        include_values=True,
	        include_metadata=True,
	    )
	    # 如果查询有结果,返回匹配的DDL语句列表,否则返回空列表
	    return [match["metadata"]["ddl"] for match in res["matches"]] if res else []

获取相关文档方法

def get_related_documentation(self, question: str, **kwargs) -> list:
    res = self.Index.query(
        namespace=self.documentation_namespace,  # 使用文档命名空间
        vector=self.generate_embedding(question),  # 生成问题的嵌入向量
        top_k=self.n_results,  # 返回的结果数量
        include_values=True,
        include_metadata=True,
    )
    # 如果查询有结果,返回匹配的文档列表,否则返回空列表
    return (
        [match["metadata"]["documentation"] for match in res["matches"]]
        if res
        else []
    )

获取相似问题和SQL方法

def get_similar_question_sql(self, question: str, **kwargs) -> list:
    res = self.Index.query(
        namespace=self.sql_namespace,  # 使用SQL命名空间
        vector=self.generate_embedding(question),  # 生成问题的嵌入向量
        top_k=self.n_results,  # 返回的结果数量
        include_values=True,
        include_metadata=True,
    )
    # 如果查询有结果,解析并返回匹配的Question-SQL对列表,否则返回空列表
    return (
        [
            {
                key: value
                for key, value in json.loads(match["metadata"]["sql"]).items()
            }
            for match in res["matches"]
        ]
        if res
        else []
    )

获取训练数据方法

def get_training_data(self, **kwargs) -> pd.DataFrame:
    df = pd.DataFrame()
    namespaces = {
        "sql": self.sql_namespace,
        "ddl": self.ddl_namespace,
        "documentation": self.documentation_namespace,
    }

    for data_type, namespace in namespaces.items():
        data = self.Index.query(
            top_k=10000,  # Pinecone允许的最大结果数量
            namespace=namespace,  # 当前命名空间
            include_values=True,
            include_metadata=True,
            vector=[0.0] * self.dimensions,  # 使用零向量查询
        )

        if data is not None:
            id_list = [match["id"] for match in data["matches"]]
            content_list = [
                match["metadata"][data_type] for match in data["matches"]
            ]
            question_list = [
                (
                    json.loads(match["metadata"][data_type])["question"]
                    if data_type == "sql"
                    else None
                )
                for match in data["matches"]
            ]

            df_data = pd.DataFrame(
                {
                    "id": id_list,
                    "question": question_list,
                    "content": content_list,
                }
            )
            df_data["training_data_type"] = data_type
            df = pd.concat([df, df_data])

    # 返回包含所有训练数据的DataFrame
    return df

移除训练数据方法

def remove_training_data(self, id: str, **kwargs) -> bool:
    if id.endswith("-sql"):
        self.Index.delete(ids=[id], namespace=self.sql_namespace)
        return True
    elif id.endswith("-ddl"):
        self.Index.delete(ids=[id], namespace=self.ddl_namespace)
        return True
    elif id.endswith("-doc"):
        self.Index.delete(ids=[id], namespace=self.documentation_namespace)
        return True
    else:
        return False
    # 根据ID的后缀判断并从相应的命名空间中删除数据

生成嵌入方法

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    embedding_model = TextEmbedding(model_name=self.fastembed_model)
    embedding = next(embedding_model.embed(data))
    return embedding.tolist()
    # 使用指定的Fastembed模型生成数据的嵌入向量,并返回生成的嵌入向量列表

qdrant

这个类实现了一个基于Qdrant的向量存储,用于存储和检索向量化的数据,如DDL(数据定义语言)、文档和SQL查询。通过该类,用户可以将文本数据转换为向量并存储在Qdrant中,并根据相似性进行检索。

初始化方法

def __init__(self, config={}):
    VannaBase.__init__(self, config=config)  # 调用父类的初始化方法
    client = config.get("client")  # 从配置中获取客户端

    if client is None:
        self._client = QdrantClient(  # 如果没有提供客户端,则创建一个新的QdrantClient实例
            location=config.get("location", None),
            url=config.get("url", None),
            prefer_grpc=config.get("prefer_grpc", False),
            https=config.get("https", None),
            api_key=config.get("api_key", None),
            timeout=config.get("timeout", None),
            path=config.get("path", None),
            prefix=config.get("prefix", None),
        )
    elif not isinstance(client, QdrantClient):
        raise TypeError(  # 如果提供的客户端不是QdrantClient实例,则抛出类型错误
            f"Unsupported client of type {client.__class__} was set in config"
        )
    else:
        self._client = client  # 使用提供的QdrantClient实例

    self.n_results = config.get("n_results", 10)  # 设置返回结果数量,默认为10
    self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")  # 设置嵌入模型
    self.collection_params = config.get("collection_params", {})  # 设置集合参数
    self.distance_metric = config.get("distance_metric", models.Distance.COSINE)  # 设置距离度量,默认为余弦距离
    self.documentation_collection_name = config.get(
        "documentation_collection_name", "documentation"
    )
    self.ddl_collection_name = config.get(
        "ddl_collection_name", "ddl"
    )
    self.sql_collection_name = config.get(
        "sql_collection_name", "sql"
    )

    self.id_suffixes = {
        self.ddl_collection_name: "ddl",
        self.documentation_collection_name: "doc",
        self.sql_collection_name: "sql",
    }

    self._setup_collections()  # 调用方法设置集合
  • 初始化 Qdrant_VectorStore 类实例。
  • 根据配置参数初始化 Qdrant 客户端。
  • 设置返回结果数量、嵌入模型、集合参数和距离度量。
  • 定义集合名称和 ID 后缀。
  • 调用 _setup_collections 方法设置集合。

添加问题-SQL对

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql)  # 组合问题和SQL对
    id = deterministic_uuid(question_answer)  # 生成唯一ID

    self._client.upsert(
        self.sql_collection_name,
        points=[
            models.PointStruct(
                id=id,
                vector=self.generate_embedding(question_answer),  # 生成嵌入向量
                payload={
                    "question": question,
                    "sql": sql,
                },
            )
        ],
    )

    return self._format_point_id(id, self.sql_collection_name)  # 返回格式化的点ID
  • 添加问题和SQL对到SQL集合。
  • 生成问题和SQL对的嵌入并插入到集合中。
  • 返回格式化的点ID。

添加DDL

def add_ddl(self, ddl: str, **kwargs) -> str:
    id = deterministic_uuid(ddl)  # 生成唯一ID
    self._client.upsert(
        self.ddl_collection_name,
        points=[
            models.PointStruct(
                id=id,
                vector=self.generate_embedding(ddl),  # 生成嵌入向量
                payload={
                    "ddl": ddl,
                },
            )
        ],
    )
    return self._format_point_id(id, self.ddl_collection_name)  # 返回格式化的点ID
  • 添加DDL到DDL集合。
  • 生成DDL的嵌入并插入到集合中。
  • 返回格式化的点ID。

添加文档

def add_documentation(self, documentation: str, **kwargs) -> str:
    id = deterministic_uuid(documentation)  # 生成唯一ID

    self._client.upsert(
        self.documentation_collection_name,
        points=[
            models.PointStruct(
                id=id,
                vector=self.generate_embedding(documentation),  # 生成嵌入向量
                payload={
                    "documentation": documentation,
                },
            )
        ],
    )

    return self._format_point_id(id, self.documentation_collection_name)  # 返回格式化的点ID
  • 添加文档到文档集合。
  • 生成文档的嵌入并插入到集合中。
  • 返回格式化的点ID。

获取训练数据

def get_training_data(self, **kwargs) -> pd.DataFrame:
    df = pd.DataFrame()  # 初始化空的DataFrame

    if sql_data := self._get_all_points(self.sql_collection_name):  # 获取SQL集合中的所有点
        question_list = [data.payload["question"] for data in sql_data]
        sql_list = [data.payload["sql"] for data in sql_data]
        id_list = [
            self._format_point_id(data.id, self.sql_collection_name)
            for data in sql_data
        ]

        df_sql = pd.DataFrame(
            {
                "id": id_list,
                "question": question_list,
                "content": sql_list,
            }
        )

        df_sql["training_data_type"] = "sql"

        df = pd.concat([df, df_sql])  # 合并到主DataFrame

    if ddl_data := self._get_all_points(self.ddl_collection_name):  # 获取DDL集合中的所有点
        ddl_list = [data.payload["ddl"] for data in ddl_data]
        id_list = [
            self._format_point_id(data.id, self.ddl_collection_name)
            for data in ddl_data
        ]

        df_ddl = pd.DataFrame(
            {
                "id": id_list,
                "question": [None for _ in ddl_list],
                "content": ddl_list,
            }
        )

        df_ddl["training_data_type"] = "ddl"

        df = pd.concat([df, df_ddl])  # 合并到主DataFrame

    if doc_data := self._get_all_points(self.documentation_collection_name):  # 获取文档集合中的所有点
        document_list = [data.payload["documentation"] for data in doc_data]
        id_list = [
            self._format_point_id(data.id, self.documentation_collection_name)
            for data in doc_data
        ]

        df_doc = pd.DataFrame(
            {
                "id": id_list,
                "question": [None for _ in document_list],
                "content": document_list,
            }
        )

        df_doc["training_data_type"] = "documentation"

        df = pd.concat([df, df_doc])  # 合并到主DataFrame

    return df  # 返回包含所有训练数据的DataFrame
  • 获取所有训练数据并返回一个包含所有数据的DataFrame。
  • 分别从SQL集合、DDL集合和文档集合中获取数据。
  • 将数据组合成DataFrame格式。

移除训练数据

def remove_training_data(self, id: str, **kwargs) -> bool:
    try:
        id, collection_name = self._parse_point_id(id)  # 解析点ID获取集合名称
        res = self._client.delete(collection_name, points_selector=[id])
        return True
    except ValueError:
        return False
  • 根据ID移除训练数据。
  • 解析点ID获取集合名称。
  • 从指定集合中删除数据。

移除集合

def remove_collection(self, collection_name: str) -> bool:
    if collection_name in self.id_suffixes.keys():
        self._client.delete_collection(collection_name)  # 删除集合
        self._setup_collections()  # 重新设置集合
        return True
    else:
        return False
  • 移除指定的集合并重置集合为空状态。
  • 根据集合名称删除集合并重新设置集合。

生成嵌入

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    embedding_model = self._client._get_or_init_model(
        model_name=self.fastembed_model
    )
    embedding = next(embedding_model.embed(data))

    return embedding.tolist()  # 返回生成的嵌入向量
  • 生成数据的嵌入向量。
  • 使用指定的Fastembed模型生成数据的嵌入向量并返回。

获取相似问题和SQL方法

def get_similar_question_sql(self, question: str, **kwargs) -> list:
    results = self._client.search(
        self.sql_collection_name,  # 使用SQL集合名称
        query_vector=self.generate_embedding(question),  # 生成问题的嵌入向量
        limit=self.n_results,  # 返回的结果数量限制
        with_payload=True,
    )

    return [dict(result.payload) for result in results]  # 返回结果列表,其中包含查询的Payload

获取相关DDL方法

def get_related_ddl(self, question: str, **kwargs) -> list:
    results = self._client.search(
        self.ddl_collection_name,  # 使用DDL集合名称
        query_vector=self.generate_embedding(question),  # 生成问题的嵌入向量
        limit=self.n_results,  # 返回的结果数量限制
        with_payload=True,
    )

    # 返回结果列表,其中包含DDL字段
    return [result.payload["ddl"] for result in results]
  • 根据输入的问题,在DDL集合中查询相关的DDL,并返回结果。

获取相关文档方法

def get_related_documentation(self, question: str, **kwargs) -> list:
    results = self._client.search(
        self.documentation_collection_name,  # 使用文档集合名称
        query_vector=self.generate_embedding(question),  # 生成问题的嵌入向量
        limit=self.n_results,  # 返回的结果数量限制
        with_payload=True,
    )

    # 返回结果列表,其中包含文档字段
    return [result.payload["documentation"] for result in results]
  • 根据输入的问题,在文档集合中查询相关的文档,并返回结果。

嵌入维度

@cached_property
def embeddings_dimension(self):
    return len(self.generate_embedding("ABCDEF"))  # 返回嵌入向量的维度
  • 返回嵌入向量的维度。

获取所有点的方法

def _get_all_points(self, collection_name: str):
    results: List[models.Record] = []
    next_offset = None
    stop_scrolling = False
    while not stop_scrolling:
        records, next_offset = self._client.scroll(
            collection_name,  # 使用指定集合名称
            limit=SCROLL_SIZE,  # 每次滚动获取的记录数量
            offset=next_offset,  # 当前滚动偏移量
            with_payload=True,
            with_vectors=False,
        )
        stop_scrolling = next_offset is None or (
            isinstance(next_offset, grpc.PointId)
            and next_offset.num == 0
            and next_offset.uuid == ""
        )

        results.extend(records)

    return results  # 返回所有记录的列表
  • 获取指定集合中的所有点记录。

设置集合的方法

def _setup_collections(self):
    if not self._client.collection_exists(self.sql_collection_name):
        self._client.create_collection(
            collection_name=self.sql_collection_name,  # 创建SQL集合
            vectors_config=models.VectorParams(
                size=self.embeddings_dimension,  # 嵌入向量的维度
                distance=self.distance_metric,  # 使用的距离度量
            ),
            **self.collection_params,
        )

    if not self._client.collection_exists(self.ddl_collection_name):
        self._client.create_collection(
            collection_name=self.ddl_collection_name,  # 创建DDL集合
            vectors_config=models.VectorParams(
                size=self.embeddings_dimension,  # 嵌入向量的维度
                distance=self.distance_metric,  # 使用的距离度量
            ),
            **self.collection_params,
        )
    if not self._client.collection_exists(self.documentation_collection_name):
        self._client.create_collection(
            collection_name=self.documentation_collection_name,  # 创建文档集合
            vectors_config=models.VectorParams(
                size=self.embeddings_dimension,  # 嵌入向量的维度
                distance=self.distance_metric,  # 使用的距离度量
            ),
            **self.collection_params,
        )
  • 设置SQL、DDL和文档集合,如果集合不存在则创建集合。

格式化点ID的方法

def _format_point_id(self, id: str, collection_name: str) -> str:
    return "{0}-{1}".format(id, self.id_suffixes[collection_name])  # 返回格式化的点ID
  • 格式化点ID,添加集合后缀。

解析点ID的方法

def _parse_point_id(self, id: str) -> Tuple[str, str]:
    id, curr_suffix = id.rsplit("-", 1)
    for collection_name, suffix in self.id_suffixes.items():
        if curr_suffix == suffix:
            return id, collection_name
    raise ValueError(f"Invalid id {id}")  # 如果ID无效则抛出异常
  • 解析点ID,返回ID和集合名称。

types

该代码定义了一些数据类和一个训练计划类,用于表示各种数据结构和训练计划。

  1. 数据类

    • 使用@dataclass装饰器定义了一系列数据类,如StatusQuestionListFullQuestionDocument等。
    • 这些类用于描述和存储不同类型的数据,如问题、答案、组织信息、数据结果等。
  2. 训练计划类

    • TrainingPlanItem类表示训练计划中的一个项目。
    • TrainingPlan类表示一个训练计划,可以获取计划的概要并移除不需要的项目。

数据类和训练计划类解析

下面是提供的代码的详细解析,包括逐行注释:

from __future__ import annotations  # 允许在当前文件中使用未来版本的特性

from dataclasses import dataclass  # 从dataclasses模块导入dataclass装饰器
from typing import Dict, List, Union  # 导入类型注解

@dataclass
class Status:
    success: bool  # 操作是否成功
    message: str  # 返回的信息

@dataclass
class StatusWithId:
    success: bool  # 操作是否成功
    message: str  # 返回的信息
    id: str  # 关联的ID

@dataclass
class QuestionList:
    questions: List[FullQuestionDocument]  # 包含问题文档的列表

@dataclass
class FullQuestionDocument:
    id: QuestionId  # 问题ID
    question: Question  # 问题内容
    answer: SQLAnswer | None  # SQL答案(如果有)
    data: DataResult | None  # 数据结果(如果有)
    plotly: PlotlyResult | None  # Plotly结果(如果有)

@dataclass
class QuestionSQLPair:
    question: str  # 问题内容
    sql: str  # 对应的SQL查询
    tag: Union[str, None]  # 标签(可选)

@dataclass
class Organization:
    name: str  # 组织名称
    user: str | None  # 用户名称(可选)
    connection: Connection | None  # 连接(可选)

@dataclass
class OrganizationList:
    organizations: List[str]  # 组织名称列表

@dataclass
class QuestionStringList:
    questions: List[str]  # 问题内容列表

@dataclass
class Visibility:
    visibility: bool  # 可见性状态

@dataclass
class UserEmail:
    email: str  # 用户电子邮件

@dataclass
class NewOrganization:
    org_name: str  # 新组织名称
    db_type: str  # 数据库类型

@dataclass
class NewOrganizationMember:
    org_name: str  # 组织名称
    email: str  # 成员电子邮件
    is_admin: bool  # 是否为管理员

@dataclass
class UserOTP:
    email: str  # 用户电子邮件
    otp: str  # 一次性密码

@dataclass
class ApiKey:
    key: str  # API密钥

@dataclass
class QuestionId:
    id: str  # 问题ID

@dataclass
class Question:
    question: str  # 问题内容

@dataclass
class QuestionCategory:
    question: str  # 问题内容
    category: str  # 问题类别

    # 预定义的类别常量
    NO_SQL_GENERATED = "No SQL Generated"
    SQL_UNABLE_TO_RUN = "SQL Unable to Run"
    BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query"
    SQL_RAN = "SQL Ran Successfully"
    FLAGGED_FOR_REVIEW = "Flagged for Review"
    REVIEWED_AND_APPROVED = "Reviewed and Approved"
    REVIEWED_AND_REJECTED = "Reviewed and Rejected"
    REVIEWED_AND_UPDATED = "Reviewed and Updated"

@dataclass
class AccuracyStats:
    num_questions: int  # 问题数量
    data: Dict[str, int]  # 各类别的问题数量

@dataclass
class Followup:
    followup: str  # 后续问题

@dataclass
class QuestionEmbedding:
    question: Question  # 问题内容
    embedding: List[float]  # 嵌入向量

@dataclass
class Connection:
    # TODO: 实现连接类
    pass

@dataclass
class SQLAnswer:
    raw_answer: str  # 原始答案
    prefix: str  # 前缀
    postfix: str  # 后缀
    sql: str  # SQL查询

@dataclass
class Explanation:
    explanation: str  # 解释内容

@dataclass
class DataResult:
    question: str | None  # 问题内容(可选)
    sql: str | None  # SQL查询(可选)
    table_markdown: str  # 表格的Markdown表示
    error: str | None  # 错误信息(可选)
    correction_attempts: int  # 修正尝试次数

@dataclass
class PlotlyResult:
    plotly_code: str  # Plotly代码

@dataclass
class WarehouseDefinition:
    name: str  # 仓库名称
    tables: List[TableDefinition]  # 表定义列表

@dataclass
class TableDefinition:
    schema_name: str  # 模式名称
    table_name: str  # 表名称
    ddl: str | None  # DDL语句(可选)
    columns: List[ColumnDefinition]  # 列定义列表

@dataclass
class ColumnDefinition:
    name: str  # 列名称
    type: str  # 列类型
    is_primary_key: bool  # 是否为主键
    is_foreign_key: bool  # 是否为外键
    foreign_key_table: str  # 外键表
    foreign_key_column: str  # 外键列

@dataclass
class Diagram:
    raw: str  # 原始内容
    mermaid_code: str  # Mermaid代码

@dataclass
class StringData:
    data: str  # 字符串数据

@dataclass
class DataFrameJSON:
    data: str  # JSON格式的数据

@dataclass
class TrainingData:
    questions: List[dict]  # 问题列表
    ddl: List[str]  # DDL语句列表
    documentation: List[str]  # 文档列表

@dataclass
class TrainingPlanItem:
    item_type: str  # 项目类型
    item_group: str  # 项目组
    item_name: str  # 项目名称
    item_value: str  # 项目值

    def __str__(self):
        if self.item_type == self.ITEM_TYPE_SQL:
            return f"Train on SQL: {self.item_group} {self.item_name}"
        elif self.item_type == self.ITEM_TYPE_DDL:
            return f"Train on DDL: {self.item_group} {self.item_name}"
        elif self.item_type == self.ITEM_TYPE_IS:
            return f"Train on Information Schema: {self.item_group} {self.item_name}"

    ITEM_TYPE_SQL = "sql"
    ITEM_TYPE_DDL = "ddl"
    ITEM_TYPE_IS = "is"

class TrainingPlan:
    """
    A class representing a training plan. You can see what's in it, and remove items from it that you don't want trained.
    **Example:**
    ```python
    plan = vn.get_training_plan()
    plan.get_summary()
    ```
    """
    _plan: List[TrainingPlanItem]

    def __init__(self, plan: List[TrainingPlanItem]):
        self._plan = plan  # 初始化训练计划

    def __str__(self):
        return "\n".join(self.get_summary())  # 返回训练计划的字符串表示

    def __repr__(self):
        return self.__str__()  # 返回训练计划的字符串表示

    def get_summary(self) -> List[str]:
        """
        **Example:**
        ```python
        plan = vn.get_training_plan()
        plan.get_summary()
        ```
        Get a summary of the training plan.
        Returns:
            List[str]: A list of strings describing the training plan.
        """
        return [f"{item}" for item in self._plan]  # 返回训练计划的概要

    def remove_item(self, item: str):
        """
        **Example:**
        ```python
        plan = vn.get_training_plan()
        plan.remove_item("Train on SQL: What is the average salary of employees?")
        ```
        Remove an item from the training plan.
        Args:
            item (str): The item to remove.
        """
        for plan_item in self._plan:
            if str(plan_item) == item:
                self._plan.remove(plan_item)  # 从训练计划中移除指定项
                break

vannadb_vector

VannaDB_VectorStore类

该类实现了一个基于Vanna API的向量存储,用于存储和检索向量化的数据,需要与Vanna官网通信。
这个类通过HTTP请求与Vanna的API进行交互,包括以下操作:

  • RPC调用:通过POST请求与Vanna的RPC端点通信,以调用各种方法(如add_sql、get_training_data等)。
  • GraphQL调用:通过POST请求与Vanna的GraphQL端点通信,以执行查询和变更(如获取函数、创建函数、更新函数等)。

初始化与配置

def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
    VannaBase.__init__(self, config=config)  # 调用父类的初始化方法
    self._model = vanna_model  # 设置模型名称
    self._api_key = vanna_api_key  # 设置API密钥

    self._endpoint = (
        "https://ask.vanna.ai/rpc"
        if config is None or "endpoint" not in config
        else config["endpoint"]
    )
    self.related_training_data = {}  # 初始化相关训练数据字典
    self._graphql_endpoint = "https://functionrag.com/query"  # 设置GraphQL端点
    self._graphql_headers = {
        "Content-Type": "application/json",
        "API-KEY": self._api_key,
        "NAMESPACE": self._model,
    }
  • 初始化VannaDB_VectorStore类实例。
  • 根据配置参数初始化API端点和相关头信息。
  • 设置模型名称和API密钥。

RPC调用方法

def _rpc_call(self, method, params):
    if method != "list_orgs":  # 设置请求头
        headers = {
            "Content-Type": "application/json",
            "Vanna-Key": self._api_key,
            "Vanna-Org": self._model,
        }
    else:
        headers = {
            "Content-Type": "application/json",
            "Vanna-Key": self._api_key,
            "Vanna-Org": "demo-tpc-h",
        }

    data = {
        "method": method,
        "params": [self._dataclass_to_dict(obj) for obj in params],  # 将参数转换为字典
    }

    response = requests.post(self._endpoint, headers=headers, data=json.dumps(data))  # 发送POST请求
    return response.json()  # 返回JSON响应
  • 发送RPC请求到Vanna API。
  • 设置请求头和请求数据。
  • 返回API响应的JSON数据。

数据类转换方法

def _dataclass_to_dict(self, obj):
    return dataclasses.asdict(obj)  # 将数据类对象转换为字典
  • 将数据类对象转换为字典格式。

获取所有函数方法

def get_all_functions(self) -> list:
    query = """
        {
            get_all_sql_functions {
                function_name
                description
                post_processing_code_template
                arguments {
                    name
                    description
                    general_type
                    is_user_editable
                    available_values
                }
                sql_template
            }
        }
    """

    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})  # 发送GraphQL请求
    response_json = response.json()  # 获取JSON响应
    if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
        self.log(response_json['data']['get_all_sql_functions'])  # 记录日志
        resp = response_json['data']['get_all_sql_functions']

        print(resp)

        return resp  # 返回函数列表
    else:
        raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
  • 发送GraphQL查询请求以获取所有SQL函数。
  • 返回函数列表。

获取函数方法

def get_function(self, question: str, additional_data: dict = {}) -> dict:
    query = """
    query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
        get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
            ... on SQLFunction {
            function_name
            description
            post_processing_code_template
            instantiated_post_processing_code
            arguments {
                name
                description
                general_type
                is_user_editable
                instantiated_value
                available_values
            }
            sql_template
            instantiated_sql
        }
        }
    }
    """
    static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
    variables = {"question": question, "staticFunctionArguments": static_function_arguments}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})  # 发送GraphQL请求
    response_json = response.json()  # 获取JSON响应
    if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
        self.log(response_json['data']['get_and_instantiate_function'])  # 记录日志
        resp = response_json['data']['get_and_instantiate_function']

        print(resp)

        return resp  # 返回函数信息
    else:
        raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
  • 发送GraphQL查询请求以获取特定问题的函数。
  • 返回函数信息。

创建函数方法

def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
    query = """
    mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
        generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
            function_name
            description
            arguments {
                name
                description
                general_type
                is_user_editable
            }
            sql_template
            post_processing_code_template
        }
    }
    """
    variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})  # 发送GraphQL请求
    response_json = response.json()  # 获取JSON响应
    if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
        resp = response_json['data']['generate_and_create_sql_function']

        print(resp)

        return resp  # 返回新创建的函数信息
    else:
        raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
  • 发送GraphQL变更请求以创建新的SQL函数。
  • 返回新创建的函数信息。

更新函数方法

def update_function(self, old_function_name: str, updated_function: dict) -> bool:
    mutation = """
    mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
        update_sql_function(input: $input)
    }
    """

    SQLFunctionUpdate = {
        'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
    }

    ArgumentKeys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}

    def validate_arguments(args):
        return [
            {key: arg[key] for key in arg if key in ArgumentKeys}
            for arg in args
        ]

    updated_function = {key: value for key, value in updated_function.items() if key in SQLFunctionUpdate}

    if 'arguments' in updated_function:
        updated_function['arguments'] = validate_arguments(updated_function['arguments'])

    variables = {
        "input": {
            "old_function_name": old_function_name,
            **updated_function
        }
    }

    print("variables", variables)

    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})  # 发送GraphQL请求
    response_json = response.json()  # 获取JSON响应
    if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
        return response_json['data']['update_sql_function']  # 返回更新结果
    else:
        raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
  • 发送GraphQL变更请求以更新现有的SQL函数。
  • 返回更新结果。

删除函数方法

def delete_function(self, function_name: str) -> bool:
    mutation = """
    mutation DeleteSQLFunction($function_name: String!) {
        delete_sql_function(function_name: $function_name)
    }
    """
    variables = {"function_name": function_name}  # 设置变量
    response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})  # 发送POST请求
    response_json = response.json()  # 获取JSON响应
    if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
        return response_json['data']['delete_sql_function']  # 返回删除结果
    else:
        raise Exception

创建模型方法

def create_model(self, model: str, **kwargs) -> bool:
    model = sanitize_model_name(model)  # 清理模型名称
    params = [NewOrganization(org_name=model, db_type="")]

    d = self._rpc_call(method="create_org", params=params)  # 调用RPC创建组织

    if "result" not in d:
        return False

    status = Status(**d["result"])

    return status.success  # 返回创建状态
  • 调用RPC创建新的模型组织。
  • 返回创建状态。

获取模型方法

def get_models(self) -> list:
    d = self._rpc_call(method="list_my_models", params=[])  # 调用RPC列出模型

    if "result" not in d:
        return []

    orgs = OrganizationList(**d["result"])

    return orgs.organizations  # 返回模型列表
  • 调用RPC获取用户的所有模型。
  • 返回模型列表。

生成嵌入方法

def generate_embedding(self, data: str, **kwargs) -> list[float]:
    pass  # 在服务器端生成嵌入
  • 生成嵌入向量的方法占位符,实际实现由服务器端完成。

添加问题和SQL方法

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    tag = kwargs.get("tag", "Manually Trained")  # 获取标签,默认为“Manually Trained”

    params = [QuestionSQLPair(question=question, sql=sql, tag=tag)]

    d = self._rpc_call(method="add_sql", params=params)  # 调用RPC添加问题和SQL对

    if "result" not in d:
        raise Exception("Error adding question and SQL pair", d)

    status = StatusWithId(**d["result"])

    return status.id  # 返回新添加的ID
  • 调用RPC添加问题和SQL对。
  • 返回新添加的ID。

添加DDL方法

def add_ddl(self, ddl: str, **kwargs) -> str:
    params = [StringData(data=ddl)]

    d = self._rpc_call(method="add_ddl", params=params)  # 调用RPC添加DDL

    if "result" not in d:
        raise Exception("Error adding DDL", d)

    status = StatusWithId(**d["result"])

    return status.id  # 返回新添加的ID
  • 调用RPC添加DDL语句。
  • 返回新添加的ID。

添加文档方法

def add_documentation(self, documentation: str, **kwargs) -> str:
    params = [StringData(data=documentation)]

    d = self._rpc_call(method="add_documentation", params=params)  # 调用RPC添加文档

    if "result" not in d:
        raise Exception("Error adding documentation", d)

    status = StatusWithId(**d["result"])

    return status.id  # 返回新添加的ID
  • 调用RPC添加文档。
  • 返回新添加的ID。

获取训练数据方法

def get_training_data(self, **kwargs) -> pd.DataFrame:
    params = []

    d = self._rpc_call(method="get_training_data", params=params)  # 调用RPC获取训练数据

    if "result" not in d:
        return None

    training_data = DataFrameJSON(**d["result"])

    df = pd.read_json(StringIO(training_data.data))  # 将JSON数据读取为DataFrame

    return df  # 返回DataFrame
  • 调用RPC获取训练数据。
  • 返回训练数据的DataFrame。

移除训练数据方法

def remove_training_data(self, id: str, **kwargs) -> bool:
    params = [StringData(data=id)]

    d = self._rpc_call(method="remove_training_data", params=params)  # 调用RPC移除训练数据

    if "result" not in d:
        raise Exception("Error removing training data")

    status = Status(**d["result"])

    if not status.success:
        raise Exception(f"Error removing training data: {status.message}")

    return status.success  # 返回移除状态
  • 调用RPC移除指定ID的训练数据。
  • 返回移除状态。

获取缓存的相关训练数据方法

def get_related_training_data_cached(self, question: str) -> TrainingData:
    params = [Question(question=question)]

    d = self._rpc_call(method="get_related_training_data", params=params)  # 调用RPC获取相关训练数据

    if "result" not in d:
        return None

    training_data = TrainingData(**d["result"])

    self.related_training_data[question] = training_data  # 缓存相关训练数据

    return training_data  # 返回训练数据
  • 调用RPC获取相关训练数据,并将其缓存。
  • 返回训练数据。

获取相似问题和SQL方法

def get_similar_question_sql(self, question: str, **kwargs) -> list:
    if question in self.related_training_data:
        training_data = self.related_training_data[question]
    else:
        training_data = self.get_related_training_data_cached(question)  # 从缓存获取相关训练数据

    return training_data.questions  # 返回相似问题和SQL对
  • 获取与指定问题相似的问题和SQL对。
  • 如果缓存中存在相关数据则直接使用,否则通过RPC获取并缓存。

获取相关DDL方法

def get_related_ddl(self, question: str, **kwargs) -> list:
    if question in self.related_training_data:
        training_data = self.related_training_data[question]
    else:
        training_data = self.get_related_training_data_cached(question)  # 从缓存获取相关训练数据

    return training_data.ddl  # 返回相关DDL
  • 获取与指定问题相关的DDL。
  • 如果缓存中存在相关数据则直接使用,否则通过RPC获取并缓存。

获取相关文档方法

def get_related_documentation(self, question: str, **kwargs) -> list:
    if question in self.related_training_data:
        training_data = self.related_training_data[question]
    else:
        training_data = self.get_related_training_data_cached(question)  # 从缓存获取相关训练数据

    return training_data.documentation  # 返回相关文档
  • 获取与指定问题相关的文档。
  • 如果缓存中存在相关数据则直接使用,否则通过RPC获取并缓存。

vllm

这个类提供了与VLLM(假设为一个虚拟语言模型)进行交互的功能。通过配置参数初始化类实例,并提供多种方法以便与VLLM进行通信和处理。

  • 该类通过配置参数初始化,并提供与VLLM服务交互的方法。
  • 可以生成系统消息、用户消息和助手消息。
  • 提供提取SQL查询和生成SQL的方法。
  • 通过提交提示与VLLM服务通信,获取生成的消息内容。

初始化方法

def __init__(self, config=None):
    if config is None or "vllm_host" not in config:
        self.host = "http://localhost:8000"  # 设置默认主机地址
    else:
        self.host = config["vllm_host"]  # 使用配置中的主机地址

    if config is None or "model" not in config:
        raise ValueError("check the config for vllm")  # 检查配置中的模型参数
    else:
        self.model = config["model"]  # 设置模型名称

    if "auth-key" in config:
        self.auth_key = config["auth-key"]  # 设置认证密钥
    else:
        self.auth_key = None  # 如果没有认证密钥,则设为None
  • 根据配置参数初始化类实例。
  • 检查并设置主机地址、模型名称和认证密钥。

消息生成方法

def system_message(self, message: str) -> any:
    return {"role": "system", "content": message}  # 生成系统消息

def user_message(self, message: str) -> any:
    return {"role": "user", "content": message}  # 生成用户消息

def assistant_message(self, message: str) -> any:
    return {"role": "assistant", "content": message}  # 生成助手消息
  • 生成系统消息、用户消息和助手消息,分别返回不同角色的消息字典。

SQL查询提取方法

def extract_sql_query(self, text):
    """
    提取第一个SQL语句,该语句在'`select`'之后,忽略大小写,
    匹配直到第一个分号、三个反引号或字符串末尾,
    并在提取的字符串中去除三个反引号。

    Args:
    - text (str): 要搜索的字符串。

    Returns:
    - str: 提取的第一个SQL语句,去除三个反引号,如果没有匹配则返回空字符串。
    """
    # 正则表达式查找'`select`'(忽略大小写),并捕获直到';'、'```'或字符串末尾
    pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

    match = pattern.search(text)
    if match:
        # 如果存在匹配,则去除匹配字符串中的三个反引号
        return match.group(0).replace("```", "")
    else:
        return text
  • 提取第一个SQL语句,该语句在select之后,忽略大小写,匹配直到第一个分号、三个反引号或字符串末尾,并在提取的字符串中去除三个反引号。

生成SQL方法

def generate_sql(self, question: str, **kwargs) -> str:
    sql = super().generate_sql(question, **kwargs)  # 调用父类的generate_sql方法

    sql = sql.replace("\\_", "_")  # 将"\_"替换为"_"
    sql = sql.replace("\\", "")  # 去除反斜杠

    return self.extract_sql_query(sql)  # 提取并返回SQL查询
  • 生成SQL查询,处理特殊字符并提取SQL语句。

提交提示方法

def submit_prompt(self, prompt, **kwargs) -> str:
    url = f"{self.host}/v1/chat/completions"
    data = {
        "model": self.model,
        "stream": False,
        "messages": prompt,
    }

    if self.auth_key is not None:
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.auth_key}'
        }
        response = requests.post(url, headers=headers, json=data)  # 发送带有认证头的POST请求
    else:
        response = requests.post(url, json=data)  # 发送不带认证头的POST请求

    response_dict = response.json()  # 获取JSON响应

    self.log(response.text)  # 记录响应日志

    return response_dict['choices'][0]['message']['content']  # 返回生成的消息内容
  • 提交提示到VLLM服务,处理响应并返回生成的消息内容。
<think>好的,我现在需要帮助用户了解如何在本地环境中部署Vanna。首先,我得回忆一下用户提供的引用内容,特别是引用[2]中提到的本地部署vanna对接数据库的步骤。根据用户的问题,他们可能需要详细的步骤指南,包括安装、配置和基本使用。 首先,我应该检查系统级指令,确保所有数学公式都正确使用$...$或$$...$$格式,但这个问题可能不涉及数学表达式,所以可能不需要处理这部分。然后,确保用中文回答,结构清晰,分步骤说明。 接下来,根据引用[2],本地部署vanna需要对目录进行分析,所以可能需要先安装必要的依赖,比如Python环境、vanna库以及数据库驱动。用户可能需要知道如何安装这些组件,比如使用pip安装vanna和相应的数据库适配器,如psycopg2用于PostgreSQL。 然后,配置Vanna可能涉及设置数据库连接,用户需要创建Vanna实例并连接到数据库。这部分需要详细说明如何初始化模型,设置API密钥(如果有的话),以及配置数据库连接参数。 之后,可能需要训练模型,让Vanna理解自然语言到SQL的转换。引用中提到需要训练数据,比如DDL语句、文档或现有查询,所以需要指导用户如何收集和输入这些数据到Vanna中。 最后,用户可能需要验证部署是否成功,例如通过自然语言提问生成SQL并执行查询。这里需要提供一个简单的示例,比如用户输入问题,Vanna生成SQL,然后执行并返回结果。 还需要确保回答中引用相关的引用内容,比如在步骤末尾添加[^2]等。同时,生成相关问题,比如如何优化Vanna的准确性、处理复杂查询或整合到现有系统等。 需要检查是否有遗漏的步骤,比如是否需要额外的配置或依赖,比如设置环境变量或处理权限问题。可能还需要提到如何处理常见的错误,比如连接失败或训练数据不足的情况。 另外,参考用户提供的示例回答结构,比如快速排序的示例,可能需要用类似的格式,分步骤说明,代码块用正确的语法高亮,比如Python代码。同时,确保语言简洁,避免技术术语过多,适合非技术人员理解。 总结步骤应该是:安装依赖、配置Vanna、训练模型、使用示例。每个步骤需要详细但不过于复杂,确保用户能逐步跟随操作。同时,注意引用来源,确保信息准确可靠。</think>### 本地部署 Vanna 的步骤指南 #### 1. **环境准备** - 安装 Python(推荐 3.8+ 版本)和包管理工具 `pip`。 - 创建虚拟环境(可选但建议): ```bash python -m venv vanna-env source vanna-env/bin/activate # Linux/macOS vanna-env\Scripts\activate # Windows ``` - 安装 Vanna 和相关依赖: ```bash pip install vanna psycopg2-binary # 以 PostgreSQL 为例,其他数据库需替换驱动(如 pymysql) ``` #### 2. **初始化 Vanna 实例** - 创建 Python 脚本(如 `vanna_demo.py`),配置 Vanna 并连接数据库: ```python from vanna.local import LocalContext_OpenAI # 初始化模型(使用本地模式或 OpenAI API) vn = LocalContext_OpenAI( config={'api_key': 'YOUR_OPENAI_KEY'} # 若使用本地模式可忽略 API 密钥 ) # 连接数据库(以 PostgreSQL 为例) vn.connect_to_postgres( host='localhost', dbname='your_database', user='your_user', password='your_password', port=5432 ) ``` #### 3. **训练模型** - **方式 1:导入数据库 DDL 语句** 提供表结构定义文件(如 `schema.sql`): ```python vn.train(ddl=""" CREATE TABLE employees ( id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50) ); """) ``` - **方式 2:加载现有 SQL 查询** 添加历史查询示例以增强模型理解: ```python vn.train(sql="SELECT name, department FROM employees WHERE id = 1001;") ``` - **方式 3:使用文档训练(可选)** 上传业务文档(如 CSV 或 Markdown)补充语义信息: ```python vn.train(documentation="部门表包含员工 ID、姓名和所属部门字段") ``` #### 4. **验证部署** - 通过自然语言生成 SQL 并执行: ```python question = "列出技术部的所有员工姓名" sql = vn.generate_sql(question) # 生成 SQL print("生成的 SQL:", sql) result = vn.run_sql(sql) # 执行查询 print("查询结果:", result) ``` - 输出示例: ``` 生成的 SQL: SELECT name FROM employees WHERE department = '技术部' 查询结果: [('张三',), ('李四',)] ``` #### 5. **扩展配置(可选)** - **自定义提示模板**:调整 `vn.train` 中的提示词优化 SQL 生成逻辑。 - **接入其他数据库**:修改 `connect_to_postgres` 为 MySQL/Snowflake 等方法。 - **日志与调试**:启用 `vn.log=true` 查看生成过程的中间步骤[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值