vanna学习日志(二)


前言

vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析

一、vanna源码分析

chromadb_vector.py:

以下代码是用于管理和操作与 ChromaDB 向量存储相关的数据。该文件定义了ChromaDB_VectorStore 类,该类继承自 VannaBase,用于处理以下任务:

1.初始化 ChromaDB 客户端:根据配置创建持久性或内存中的 ChromaDB 客户端。
2.创建或获取集合:包括 documentation、ddl 和 sql 三个集合,用于存储相应的数据。
3.生成嵌入向量:使用嵌入函数生成数据的向量表示。
4.添加数据:将问题与 SQL、DDL 或文档添加到相应的集合中。
5.获取训练数据:从集合中提取数据并返回 pandas DataFrame。
6.删除数据:根据 ID 从集合中删除特定数据。
7.删除集合:重置集合为空状态。
8.查询相似数据:从集合中查询与问题相关的 SQL、DDL 或文档。
import json  # 导入 json 模块,用于处理 JSON 数据
from typing import List  # 从 typing 模块导入 List 类型,用于类型注解

import chromadb  # 导入 chromadb 模块,用于操作 Chroma 数据库
import pandas as pd  # 导入 pandas 模块,用于数据处理
from chromadb.config import Settings  # 从 chromadb.config 模块导入 Settings 类
from chromadb.utils import embedding_functions  # 从 chromadb.utils 模块导入 embedding_functions

from ..base import VannaBase  # 从上一级目录的 base 模块导入 VannaBase 类
from ..utils import deterministic_uuid  # 从上一级目录的 utils 模块导入 deterministic_uuid 函数

default_ef = embedding_functions.DefaultEmbeddingFunction()  # 获取默认的嵌入函数实例


class ChromaDB_VectorStore(VannaBase):  # 定义 ChromaDB_VectorStore 类,继承自 VannaBase
    def __init__(self, config=None):  # 初始化函数,接受可选的配置参数
        VannaBase.__init__(self, config=config)  # 调用父类的初始化方法
        if config is None:  # 如果未提供配置,使用空字典作为默认配置
            config = {}

        path = config.get("path", ".")  # 获取配置中的路径,默认为当前目录
        self.embedding_function = config.get("embedding_function", default_ef)  # 获取嵌入函数,默认为 default_ef
        curr_client = config.get("client", "persistent")  # 获取客户端类型,默认为持久化客户端
        collection_metadata = config.get("collection_metadata", None)  # 获取集合元数据
        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))  # 获取 SQL 查询结果数量,默认为 10
        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))  # 获取文档查询结果数量,默认为 10
        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))  # 获取 DDL 查询结果数量,默认为 10

        if curr_client == "persistent":  # 如果客户端类型是持久化的
            self.chroma_client = chromadb.PersistentClient(
                path=path, settings=Settings(anonymized_telemetry=False)
            )
        elif curr_client == "in-memory":  # 如果客户端类型是内存中的
            self.chroma_client = chromadb.EphemeralClient(
                settings=Settings(anonymized_telemetry=False)
            )
        elif isinstance(curr_client, chromadb.api.client.Client):  # 如果直接提供了客户端实例
            self.chroma_client = curr_client  # 使用提供的客户端实例
        else:
            raise ValueError(f"Unsupported client was set in config: {curr_client}")  # 如果客户端类型不支持,抛出异常

        # 获取或创建 documentation 集合
        self.documentation_collection = self.chroma_client.get_or_create_collection(
            name="documentation",
            embedding_function=self.embedding_function,
            metadata=collection_metadata,
        )
        # 获取或创建 ddl 集合
        self.ddl_collection = self.chroma_client.get_or_create_collection(
            name="ddl",
            embedding_function=self.embedding_function,
            metadata=collection_metadata,
        )
        # 获取或创建 sql 集合
        self.sql_collection = self.chroma_client.get_or_create_collection(
            name="sql",
            embedding_function=self.embedding_function,
            metadata=collection_metadata,
        )

    def generate_embedding(self, data: str, **kwargs) -> List[float]:  # 生成嵌入
        embedding = self.embedding_function([data])  # 调用嵌入函数生成嵌入
        if len(embedding) == 1:  # 如果生成的嵌入只有一个
            return embedding[0]  # 返回该嵌入
        return embedding  # 否则返回整个嵌入列表

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:  # 添加 SQL 问题
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )  # 将问题和 SQL 转换为 JSON 字符串
        id = deterministic_uuid(question_sql_json) + "-sql"  # 生成唯一 ID,并添加 "-sql" 后缀
        self.sql_collection.add(
            documents=question_sql_json,  # 添加文档
            embeddings=self.generate_embedding(question_sql_json),  # 生成并添加嵌入
            ids=id,  # 添加 ID
        )
        return id  # 返回 ID

    def add_ddl(self, ddl: str, **kwargs) -> str:  # 添加 DDL
        id = deterministic_uuid(ddl) + "-ddl"  # 生成唯一 ID,并添加 "-ddl" 后缀
        self.ddl_collection.add(
            documents=ddl,  # 添加文档
            embeddings=self.generate_embedding(ddl),  # 生成并添加嵌入
            ids=id,  # 添加 ID
        )
        return id  # 返回 ID

    def add_documentation(self, documentation: str, **kwargs) -> str:  # 添加文档
        id = deterministic_uuid(documentation) + "-doc"  # 生成唯一 ID,并添加 "-doc" 后缀
        self.documentation_collection.add(
            documents=documentation,  # 添加文档
            embeddings=self.generate_embedding(documentation),  # 生成并添加嵌入
            ids=id,  # 添加 ID
        )
        return id  # 返回 ID

    def get_training_data(self, **kwargs) -> pd.DataFrame:  # 获取训练数据
        sql_data = self.sql_collection.get()  # 获取 SQL 集合中的数据
        df = pd.DataFrame()  # 创建一个空的 DataFrame

        if sql_data is not None:  # 如果 SQL 数据不为空
            documents = [json.loads(doc) for doc in sql_data["documents"]]  # 解析 JSON 文档
            ids = sql_data["ids"]  # 获取文档 ID
            df_sql = pd.DataFrame(
                {
                    "id": ids,
                    "question": [doc["question"] for doc in documents],
                    "content": [doc["sql"] for doc in documents],
                }
            )  # 创建 SQL 数据的 DataFrame
            df_sql["training_data_type"] = "sql"  # 添加数据类型列
            df = pd.concat([df, df_sql])  # 将 SQL 数据合并到主 DataFrame

        ddl_data = self.ddl_collection.get()  # 获取 DDL 集合中的数据
        if ddl_data is not None:  # 如果 DDL 数据不为空
            documents = [doc for doc in ddl_data["documents"]]  # 获取文档
            ids = ddl_data["ids"]  # 获取文档 ID
            df_ddl = pd.DataFrame(
                {
                    "id": ids,
                    "question": [None for doc in documents],
                    "content": [doc for doc in documents],
                }
            )  # 创建 DDL 数据的 DataFrame
            df_ddl["training_data_type"] = "ddl"  # 添加数据类型列
            df = pd.concat([df, df_ddl])  # 将 DDL 数据合并到主 DataFrame

        doc_data = self.documentation_collection.get()  # 获取文档集合中的数据
        if doc_data is not None:  # 如果文档数据不为空
            documents = [doc for doc in doc_data["documents"]]  # 获取文档
            ids = doc_data["ids"]  # 获取文档 ID
            df_doc = pd.DataFrame(
                {
                    "id": ids,
                    "question": [None for doc in documents],
                    "content": [doc for doc in documents],
                }
            )  # 创建文档数据的 DataFrame
            df_doc["training_data_type"] = "documentation"  # 添加数据类型列
            df = pd.concat([df, df_doc])  # 将文档数据合并到主 DataFrame

        return df  # 返回最终的 DataFrame

    def remove_training_data(self, id: str, **kwargs) -> bool:  # 删除训练数据
        if id.endswith("-sql"):  # 如果 ID 以 "-sql" 结尾
            self.sql_collection.delete(ids=id)  # 删除 SQL 集合中的该 ID
            return True
        elif id.endswith("-ddl"):  # 如果 ID 以 "-ddl" 结尾
            self.ddl_collection.delete(ids=id)  # 删除 DDL 集合中的该 ID
            return True
        elif id.endswith("-doc"):  # 如果 ID 以 "-doc" 结尾
            self.documentation_collection.delete(ids=id)  # 删除文档集合中的该 ID
            return True
        else:
            return False  # 如果 ID 不符合上述任何条件,返回 False

    def remove_collection(self, collection_name: str) -> bool:  # 删除集合
        """
        This function can reset the collection to empty state.

        Args:
            collection_name (str): sql or ddl or documentation

        Returns:


            bool: True if collection is deleted, False otherwise
        """
        if collection_name == "sql":  # 如果集合名称是 "sql"
            self.chroma_client.delete_collection(name="sql")  # 删除 sql 集合
            self.sql_collection = self.chroma_client.get_or_create_collection(
                name="sql", embedding_function=self.embedding_function
            )  # 重新创建 sql 集合
            return True
        elif collection_name == "ddl":  # 如果集合名称是 "ddl"
            self.chroma_client.delete_collection(name="ddl")  # 删除 ddl 集合
            self.ddl_collection = self.chroma_client.get_or_create_collection(
                name="ddl", embedding_function=self.embedding_function
            )  # 重新创建 ddl 集合
            return True
        elif collection_name == "documentation":  # 如果集合名称是 "documentation"
            self.chroma_client.delete_collection(name="documentation")  # 删除 documentation 集合
            self.documentation_collection = self.chroma_client.get_or_create_collection(
                name="documentation", embedding_function=self.embedding_function
            )  # 重新创建 documentation 集合
            return True
        else:
            return False  # 如果集合名称不符合上述任何条件,返回 False

    @staticmethod
    def _extract_documents(query_results) -> list:  # 静态方法:从查询结果中提取文档
        """
        Static method to extract the documents from the results of a query.

        Args:
            query_results (pd.DataFrame): The dataframe to use.

        Returns:
            List[str] or None: The extracted documents, or an empty list or
            single document if an error occurred.
        """
        if query_results is None:  # 如果查询结果为空
            return []  # 返回空列表

        if "documents" in query_results:  # 如果查询结果包含 "documents" 字段
            documents = query_results["documents"]

            if len(documents) == 1 and isinstance(documents[0], list):  # 如果只有一个文档并且是列表类型
                try:
                    documents = [json.loads(doc) for doc in documents[0]]  # 尝试解析 JSON 文档
                except Exception as e:  # 如果解析失败
                    return documents[0]  # 返回原始文档

            return documents  # 返回文档列表

    def get_similar_question_sql(self, question: str, **kwargs) -> list:  # 获取类似问题的 SQL
        return ChromaDB_VectorStore._extract_documents(
            self.sql_collection.query(
                query_texts=[question],  # 查询文本
                n_results=self.n_results_sql,  # 查询结果数量
            )
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:  # 获取相关的 DDL
        return ChromaDB_VectorStore._extract_documents(
            self.ddl_collection.query(
                query_texts=[question],  # 查询文本
                n_results=self.n_results_ddl,  # 查询结果数量
            )
        )

    def get_related_documentation(self, question: str, **kwargs) -> list:  # 获取相关的文档
        return ChromaDB_VectorStore._extract_documents(
            self.documentation_collection.query(
                query_texts=[question],  # 查询文本
                n_results=self.n_results_documentation,  # 查询结果数量
            )
        )
  • 7
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值