vanna学习日志(四)


前言

vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析,该部分内容为接入各类ai方法。

一、vanna源码分析

google

gemini_chat.py

这段代码定义了一个 GoogleGeminiChat 类,继承自 VannaBase

  • 这个类主要用于与 Google 的生成式 AI 模型进行交互,特别是用于聊天应用。
  • 通过配置 API 密钥和模型名称,可以灵活地使用不同的生成模型。
  • 提供了简单的消息处理方法,可以根据需要进行扩展。
import os
from ..base import VannaBase  # 从父模块导入 VannaBase 类

class GoogleGeminiChat(VannaBase):
    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)  # 调用父类的构造函数

        # 默认的温度值,可以通过 config 覆盖
        self.temperature = 0.7

        # 如果配置中包含 temperature,则覆盖默认值
        if "temperature" in config:
            self.temperature = config["temperature"]

        # 设置模型名称,如果配置中没有指定,则使用默认值 "gemini-1.0-pro"
        if "model_name" in config:
            model_name = config["model_name"]
        else:
            model_name = "gemini-1.0-pro"

        self.google_api_key = None  # 初始化 API 密钥变量

        # 如果配置中提供了 api_key 或环境变量中有 GOOGLE_API_KEY,则使用它
        if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
            """
            如果 Google api_key 通过配置提供
            或设置为环境变量,则分配它。
            """
            import google.generativeai as genai  # 导入 google.generativeai 库

            genai.configure(api_key=config["api_key"])  # 使用提供的 API 密钥进行配置
            self.chat_model = genai.GenerativeModel(model_name)  # 初始化生成模型
        else:
            # 使用 VertexAI 进行身份验证
            from vertexai.preview.generative_models import GenerativeModel  # 导入 VertexAI 的生成模型类
            self.chat_model = GenerativeModel("gemini-pro")  # 初始化生成模型

    def system_message(self, message: str) -> any:
        return message  # 返回系统消息

    def user_message(self, message: str) -> any:
        return message  # 返回用户消息

    def assistant_message(self, message: str) -> any:
        return message  # 返回助手消息

    def submit_prompt(self, prompt, **kwargs) -> str:
        # 使用生成模型生成内容
        response = self.chat_model.generate_content(
            prompt,
            generation_config={
                "temperature": self.temperature,  # 使用配置的温度值
            },
        )
        return response.text  # 返回生成的文本
功能和作用
  1. 类的初始化

    • 初始化时调用父类 VannaBase 的构造函数。
    • 设置默认的温度值 self.temperature,可以通过配置覆盖。
    • 设置模型名称 model_name,可以通过配置覆盖,默认使用 "gemini-1.0-pro"
    • 检查是否提供了 API 密钥,如果提供了,则使用 google.generativeai 库进行配置并初始化生成模型。如果没有提供 API 密钥,则使用 VertexAI 进行身份验证并初始化生成模型。
  2. 消息处理方法

    • system_messageuser_messageassistant_message 方法都是简单地返回传入的消息。这些方法可以在实际应用中进行扩展,以处理不同类型的消息。
  3. 提交提示

    • submit_prompt 方法使用生成模型生成内容。它接受一个提示(prompt)并生成相应的文本。生成的配置包括温度值。

hf

hf.py

这段代码定义了一个 Hf 类,继承自 VannaBase

  • 这个类主要用于与 Hugging Face 的生成模型进行交互,特别是用于生成 SQL 查询。
  • 通过配置模型名称,可以灵活地使用不同的生成模型。
  • 提供了简单的消息处理方法和 SQL 提取方法,可以根据需要进行扩展。
import re  # 导入正则表达式模块
from transformers import AutoTokenizer, AutoModelForCausalLM  # 从 transformers 库中导入自动分词器和因果语言模型

from ..base import VannaBase  # 从父模块导入 VannaBase 类

class Hf(VannaBase):
    def __init__(self, config=None):
        # 从配置中获取模型名称,例如 "meta-llama/Meta-Llama-3-8B-Instruct"
        model_name = self.config.get("model_name", None)

        # 从预训练模型中加载分词器
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # 从预训练模型中加载因果语言模型,设置数据类型和设备映射为自动
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )

    def system_message(self, message: str) -> any:
        # 返回包含角色和内容的字典,表示系统消息
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        # 返回包含角色和内容的字典,表示用户消息
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        # 返回包含角色和内容的字典,表示助手消息
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        提取第一个 SQL 语句,从 'select' 开始,不区分大小写,
        匹配到第一个分号、三个反引号或字符串结尾,
        并移除提取字符串中的三个反引号(如果存在)。

        参数:
        - text (str): 要搜索 SQL 语句的字符串。

        返回:
        - str: 找到的第一个 SQL 语句,移除三个反引号后的结果,如果没有匹配则返回空字符串。
        """
        # 正则表达式模式,用于查找 'select'(忽略大小写)并捕获到分号、三个反引号或字符串结尾
        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

        # 搜索匹配项
        match = pattern.search(text)
        if match:
            # 如果存在匹配项,移除匹配字符串中的三个反引号
            return match.group(0).replace("```", "")
        else:
            # 如果没有匹配项,返回输入字符串
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # 使用父类的 generate_sql 方法生成 SQL 语句
        sql = super().generate_sql(question, **kwargs)

        # 替换字符串中的 "\_" 为 "_"
        sql = sql.replace("\\_", "_")

        # 替换字符串中的 "\" 为 ""
        sql = sql.replace("\\", "")

        # 提取并返回 SQL 查询
        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        # 使用分词器对提示进行处理,添加生成提示并返回张量格式
        input_ids = self.tokenizer.apply_chat_template(
            prompt, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        # 使用模型生成输出,设置最大新标记数、结束标记 ID、是否进行采样、温度和 top-p 参数
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=512,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            temperature=1,
            top_p=0.9,
        )
        # 提取生成的响应,跳过输入部分
        response = outputs[0][input_ids.shape[-1] :]

        # 解码生成的响应,跳过特殊标记
        response = self.tokenizer.decode(response, skip_special_tokens=True)

        # 记录响应
        self.log(response)

        # 返回响应
        return response
功能和作用
  1. 类的初始化

    • 初始化时从配置中获取模型名称,并加载相应的分词器和因果语言模型。
    • 设置模型的 torch_dtypedevice_mapauto,以便自动调整数据类型和设备。
  2. 消息处理方法

    • system_messageuser_messageassistant_message 方法返回包含角色和内容的字典,用于表示不同类型的消息。
  3. 提取 SQL 查询

    • extract_sql_query 方法使用正则表达式从输入文本中提取第一个 SQL 语句,匹配到分号、三个反引号或字符串结尾,并移除提取字符串中的三个反引号。
  4. 生成 SQL 查询

    • generate_sql 方法首先调用父类的 generate_sql 方法生成 SQL 语句,然后替换字符串中的特定字符,并使用 extract_sql_query 方法提取 SQL 查询。
  5. 提交提示

    • submit_prompt 方法使用分词器对提示进行处理,并使用模型生成响应。生成的响应经过解码和记录后返回。

marqo

marqo.py

  • 这个类主要用于与 Marqo 向量存储进行交互,管理和搜索与 SQL 查询、DDL 以及文档相关的数据。
  • 提供了一套方法,用于添加、检索和删除训练数据,并支持通过相似性搜索获取相关的 SQL、DDL 和文档。
import uuid  # 导入用于生成唯一标识符的模块
import marqo  # 导入 Marqo 库,用于处理向量存储
import pandas as pd  # 导入 pandas 库,用于数据处理

from ..base import VannaBase  # 从父模块导入 VannaBase 类

class Marqo_VectorStore(VannaBase):
    def __init__(self, config=None):
        # 调用父类的构造函数进行初始化
        VannaBase.__init__(self, config=config)

        # 检查配置中是否包含 marqo_url,否则使用默认值
        if config is not None and "marqo_url" in config:
            marqo_url = config["marqo_url"]
        else:
            marqo_url = "http://localhost:8882"

        # 检查配置中是否包含 marqo_model,否则使用默认模型名称
        if config is not None and "marqo_model" in config:
            marqo_model = config["marqo_model"]
        else:
            marqo_model = "hf/all_datasets_v4_MiniLM-L6"

        # 创建 Marqo 客户端
        self.mq = marqo.Client(url=marqo_url)

        # 创建三个索引:vanna-sql、vanna-ddl、vanna-doc
        for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]:
            try:
                # 尝试创建索引
                self.mq.create_index(index, model=marqo_model)
            except Exception as e:
                # 如果索引已经存在,则捕获异常并打印错误信息
                print(e)
                print(f"Marqo index {index} already exists")
                pass

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        # Marqo 不需要生成嵌入
        pass

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        # 生成唯一标识符并添加 "-sql" 后缀
        id = str(uuid.uuid4()) + "-sql"
        # 创建包含问题和 SQL 的字典
        question_sql_dict = {
            "question": question,
            "sql": sql,
            "_id": id,
        }

        # 将文档添加到 "vanna-sql" 索引中
        self.mq.index("vanna-sql").add_documents(
            [question_sql_dict],
            tensor_fields=["question", "sql"],
        )

        return id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        # 生成唯一标识符并添加 "-ddl" 后缀
        id = str(uuid.uuid4()) + "-ddl"
        # 创建包含 DDL 的字典
        ddl_dict = {
            "ddl": ddl,
            "_id": id,
        }

        # 将文档添加到 "vanna-ddl" 索引中
        self.mq.index("vanna-ddl").add_documents(
            [ddl_dict],
            tensor_fields=["ddl"],
        )
        return id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        # 生成唯一标识符并添加 "-doc" 后缀
        id = str(uuid.uuid4()) + "-doc"
        # 创建包含文档的字典
        doc_dict = {
            "doc": documentation,
            "_id": id,
        }

        # 将文档添加到 "vanna-doc" 索引中
        self.mq.index("vanna-doc").add_documents(
            [doc_dict],
            tensor_fields=["doc"],
        )
        return id

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        data = []  # 初始化一个空列表用于存储数据

        # 从 "vanna-doc" 索引中检索文档
        for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "documentation",
                    "question": "",
                    "content": hit["doc"],
                }
            )

        # 从 "vanna-ddl" 索引中检索文档
        for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "ddl",
                    "question": "",
                    "content": hit["ddl"],
                }
            )

        # 从 "vanna-sql" 索引中检索文档
        for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]:
            data.append(
                {
                    "id": hit["_id"],
                    "training_data_type": "sql",
                    "question": hit["question"],
                    "content": hit["sql"],
                }
            )

        # 将数据转换为 DataFrame 并返回
        df = pd.DataFrame(data)

        return df

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # 根据 ID 后缀确定要删除的索引中的文档
        if id.endswith("-sql"):
            self.mq.index("vanna-sql").delete_documents(ids=[id])
            return True
        elif id.endswith("-ddl"):
            self.mq.index("vanna-ddl").delete_documents(ids=[id])
            return True
        elif id.endswith("-doc"):
            self.mq.index("vanna-doc").delete_documents(ids=[id])
            return True
        else:
            return False

    @staticmethod
    def _extract_documents(data) -> list:
        # 检查数据中是否包含 'hits' 键且其值是否为列表
        if "hits" in data and isinstance(data["hits"], list):
            # 如果 'hits' 列表为空,则返回空列表
            if len(data["hits"]) == 0:
                return []

            # 如果 'hits' 中包含 "doc" 键,则返回其值
            if "doc" in data["hits"][0]:
                return [hit["doc"] for hit in data["hits"]]

            # 如果 'hits' 中包含 "ddl" 键,则返回其值
            if "ddl" in data["hits"][0]:
                return [hit["ddl"] for hit in data["hits"]]

            # 否则,返回所有命中的项目
            return [
                {key: value for key, value in hit.items() if not key.startswith("_")}
                for hit in data["hits"]
            ]
        else:
            # 如果 'hits' 不存在或不是列表,则返回空列表
            return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        # 从 "vanna-sql" 索引中搜索相似的问题 SQL,并提取文档
        return Marqo_VectorStore._extract_documents(
            self.mq.index("vanna-sql").search(question)
        )

    def get_related_ddl(self, question: str, **kwargs) -> list:
        # 从 "vanna-ddl" 索引中搜索相关的 DDL,并提取文档
        return Marqo_VectorStore._extract_documents(
            self.mq.index("vanna-ddl").search(question)
        )

    def get_related_documentation(self, question: str, **kwargs) -> list:
        # 从 "vanna-doc" 索引中搜索相关的文档,并提取文档
        return Marqo_VectorStore._extract_documents(
            self.mq.index("vanna-doc").search(question)
        )
功能和作用
  1. 类的初始化

    • 初始化时调用父类的构造函数。
    • 从配置中读取 marqo_urlmarqo_model,如果未提供则使用默认值。
    • 创建 Marqo 客户端并尝试创建三个索引:vanna-sqlvanna-ddlvanna-doc
  2. 生成嵌入

    • generate_embedding 方法目前没有实现,因为 Marqo 不需要生成嵌入。
  3. 添加问题和 SQL

    • add_question_sql 方法生成一个唯一标识符,并将问题和 SQL 添加到 vanna-sql 索引中。
  4. 添加 DDL

    • add_ddl 方法生成一个唯一标识符,并将 DDL 添加到 vanna-ddl 索引中。
  5. 添加文档

    • add_documentation 方法生成一个唯一标识符,并将文档添加到 vanna-doc 索引中。
  6. 获取训练数据

    • get_training_data 方法从三个索引中检索文档并转换为 pandas DataFrame。
  7. 删除训练数据

    • remove_training_data 方法根据文档 ID 后缀确定要删除的索引中的文档。
  8. 静态方法提取文档

    • _extract_documents 静态方法从搜索结果中提取文档。
  9. 获取相似问题的 SQL

    • get_similar_question_sql 方法从 vanna-sql 索引中搜索相似的问题 SQL,并提取文档。
  10. 获取相关的 DDL

  • get_related_ddl 方法从 vanna-ddl 索引中搜索相关的 DDL,并提取文档。
  1. 获取相关文档
  • get_related_documentation 方法从 vanna-doc 索引中搜索相关的文档,并提取文档。

mistral

mistral.py

  • 这个类主要用于与 Mistral AI 服务进行交互,处理聊天消息并生成相应的 SQL 查询。
  • 提供了一套方法,用于构建系统消息、用户消息和助手消息,并支持发送聊天请求和处理响应。
from mistralai.client import MistralClient  # 从 Mistral 库导入 MistralClient
from mistralai.models.chat_completion import ChatMessage  # 从 Mistral 库导入 ChatMessage 模型

from ..base import VannaBase  # 从父模块导入 VannaBase 类


class Mistral(VannaBase):
    def __init__(self, config=None):
        # 如果没有提供配置,抛出 ValueError 异常
        if config is None:
            raise ValueError(
                "For Mistral, config must be provided with an api_key and model"
            )

        # 如果配置中不包含 api_key,抛出 ValueError 异常
        if "api_key" not in config:
            raise ValueError("config must contain a Mistral api_key")

        # 如果配置中不包含 model,抛出 ValueError 异常
        if "model" not in config:
            raise ValueError("config must contain a Mistral model")

        # 从配置中获取 api_key 和 model
        api_key = config["api_key"]
        model = config["model"]
        # 创建 Mistral 客户端实例
        self.client = MistralClient(api_key=api_key)
        # 设置模型名称
        self.model = model

    def system_message(self, message: str) -> any:
        # 返回一个系统消息对象
        return ChatMessage(role="system", content=message)

    def user_message(self, message: str) -> any:
        # 返回一个用户消息对象
        return ChatMessage(role="user", content=message)

    def assistant_message(self, message: str) -> any:
        # 返回一个助手消息对象
        return ChatMessage(role="assistant", content=message)

    def generate_sql(self, question: str, **kwargs) -> str:
        # 使用父类的方法生成 SQL 查询
        sql = super().generate_sql(question, **kwargs)

        # 将 "\_" 替换为 "_"
        sql = sql.replace("\\_", "_")

        return sql

    def submit_prompt(self, prompt, **kwargs) -> str:
        # 使用 Mistral 客户端发送聊天请求
        chat_response = self.client.chat(
            model=self.model,
            messages=prompt,
        )

        # 返回聊天响应中的消息内容
        return chat_response.choices[0].message.content
功能和作用
  1. 类的初始化

    • 初始化时检查配置是否包含 api_keymodel,如果缺少任意一个则抛出 ValueError 异常。
    • 创建 Mistral 客户端实例,并设置模型名称。
  2. 系统消息

    • system_message 方法创建并返回一个系统消息对象。
  3. 用户消息

    • user_message 方法创建并返回一个用户消息对象。
  4. 助手消息

    • assistant_message 方法创建并返回一个助手消息对象。
  5. 生成 SQL 查询

    • generate_sql 方法调用父类的方法生成 SQL 查询,然后替换其中的 “_” 为 “_” 并返回最终的 SQL 查询。
  6. 提交提示

    • submit_prompt 方法使用 Mistral 客户端发送聊天请求,并返回聊天响应中的消息内容。

mock

embedding.py

这段代码定义了一个名为 MockEmbedding 的类,继承自 VannaBase

  • MockEmbedding 类包含一个构造函数和一个 generate_embedding 方法,后者返回一个固定的浮点数列表。
  • 该类主要用于测试和开发阶段,提供一个简单的嵌入生成实现。
from typing import List  # 从 typing 模块导入 List 类型,用于类型注解

from ..base import VannaBase  # 从上一级目录的 base 模块导入 VannaBase 类


class MockEmbedding(VannaBase):  # 定义一个名为 MockEmbedding 的类,继承自 VannaBase
    def __init__(self, config=None):  # 定义类的构造函数,接受一个可选的配置参数 config
        pass  # 目前构造函数不做任何操作

    def generate_embedding(self, data: str, **kwargs) -> List[float]:  # 定义一个名为 generate_embedding 的方法,接受一个字符串参数 data 和其他可选参数,返回一个浮点数列表
        return [1.0, 2.0, 3.0, 4.0, 5.0]  # 返回一个固定的浮点数列表
功能和作用
  1. 导入 List 类型

    • typing 模块导入 List 类型,用于类型注解,指定 generate_embedding 方法返回值的类型。
  2. 导入 VannaBase

    • 从上一级目录的 base 模块导入 VannaBase 类。VannaBase 类可能是所有具体实现的基类,提供一些基本的功能和接口。
  3. 定义 MockEmbedding

    • MockEmbedding 类继承自 VannaBase,用于模拟嵌入生成的功能,通常在测试或开发阶段使用。
  4. 构造函数 __init__

    • 定义类的构造函数,接受一个可选的配置参数 config。目前构造函数不做任何实际操作,仅包含一个 pass 语句。
  5. generate_embedding 方法

    • 定义一个名为 generate_embedding 的方法,接受一个字符串参数 data 和其他可选参数,返回一个浮点数列表。
    • 方法的实现简单地返回一个固定的浮点数列表 [1.0, 2.0, 3.0, 4.0, 5.0],模拟生成的嵌入向量。

llm.py

这段代码定义了一个名为 MockLLM 的类,继承自 VannaBase

  • MockLLM 类包含一个构造函数和多个消息创建方法,以及一个提交提示的方法,后者返回一个固定的字符串响应。
  • 该类主要用于测试和开发阶段,提供一个简单的大语言模型模拟实现。
from ..base import VannaBase  # 从上一级目录的 base 模块导入 VannaBase 类


class MockLLM(VannaBase):  # 定义一个名为 MockLLM 的类,继承自 VannaBase
    def __init__(self, config=None):  # 定义类的构造函数,接受一个可选的配置参数 config
        pass  # 目前构造函数不做任何操作

    def system_message(self, message: str) -> any:  # 定义一个名为 system_message 的方法,接受一个字符串参数 message,返回一个任意类型的值
        return {"role": "system", "content": message}  # 返回一个包含角色和内容的字典,角色为 "system"

    def user_message(self, message: str) -> any:  # 定义一个名为 user_message 的方法,接受一个字符串参数 message,返回一个任意类型的值
        return {"role": "user", "content": message}  # 返回一个包含角色和内容的字典,角色为 "user"

    def assistant_message(self, message: str) -> any:  # 定义一个名为 assistant_message 的方法,接受一个字符串参数 message,返回一个任意类型的值
        return {"role": "assistant", "content": message}  # 返回一个包含角色和内容的字典,角色为 "assistant"

    def submit_prompt(self, prompt, **kwargs) -> str:  # 定义一个名为 submit_prompt 的方法,接受一个参数 prompt 和其他可选参数,返回一个字符串
        return "Mock LLM response"  # 返回一个固定的字符串 "Mock LLM response"
代码功能和作用
  1. 导入 VannaBase

    • 从上一级目录的 base 模块导入 VannaBase 类。VannaBase 类可能是所有具体实现的基类,提供一些基本的功能和接口。
  2. 定义 MockLLM

    • MockLLM 类继承自 VannaBase,用于模拟大语言模型(LLM)的功能,通常在测试或开发阶段使用。
  3. 构造函数 __init__

    • 定义类的构造函数,接受一个可选的配置参数 config。目前构造函数不做任何实际操作,仅包含一个 pass 语句。
  4. system_message 方法

    • 定义一个名为 system_message 的方法,接受一个字符串参数 message,返回一个包含角色和内容的字典,角色为 “system”。
    • 该方法用于创建系统消息的结构。
  5. user_message 方法

    • 定义一个名为 user_message 的方法,接受一个字符串参数 message,返回一个包含角色和内容的字典,角色为 “user”。
    • 该方法用于创建用户消息的结构。
  6. assistant_message 方法

    • 定义一个名为 assistant_message 的方法,接受一个字符串参数 message,返回一个包含角色和内容的字典,角色为 “assistant”。
    • 该方法用于创建助手消息的结构。
  7. submit_prompt 方法

    • 定义一个名为 submit_prompt 的方法,接受一个参数 prompt 和其他可选参数,返回一个字符串。
    • 该方法返回一个固定的字符串 “Mock LLM response”,模拟大语言模型的响应。

vectordb.py

MockVectorDB 类模拟了一个简单的向量数据库,实现了一些基本的数据库操作和数据处理功能。该类主要用于测试和开发阶段,提供一个简单的实现,而无需依赖实际的数据库操作。

  • 数据处理:通过 pandas DataFrame 返回训练数据,便于数据的处理和操作。
  • 接口实现:实现了一些基本的接口方法,如获取相关的 DDL、文档和类似问题及 SQL,这些方法在测试和开发阶段可以返回空值或固定值。
import pandas as pd  # 导入 pandas 库,用于处理数据

from ..base import VannaBase  # 从上一级目录的 base 模块导入 VannaBase 类


class MockVectorDB(VannaBase):  # 定义一个名为 MockVectorDB 的类,继承自 VannaBase
    def __init__(self, config=None):  # 定义类的构造函数,接受一个可选的配置参数 config
        pass  # 目前构造函数不做任何操作

    def _get_id(self, value: str, **kwargs) -> str:  # 定义一个私有方法 _get_id,接受一个字符串参数 value 和其他可选参数
        # 将值进行哈希处理并返回 ID
        return str(hash(value))  # 返回值的哈希值转换成字符串形式

    def add_ddl(self, ddl: str, **kwargs) -> str:  # 定义一个方法 add_ddl,接受一个字符串参数 ddl 和其他可选参数
        return self._get_id(ddl)  # 调用 _get_id 方法并返回其结果

    def add_documentation(self, doc: str, **kwargs) -> str:  # 定义一个方法 add_documentation,接受一个字符串参数 doc 和其他可选参数
        return self._get_id(doc)  # 调用 _get_id 方法并返回其结果

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:  # 定义一个方法 add_question_sql,接受一个字符串参数 question 和 sql 以及其他可选参数
        return self._get_id(question)  # 调用 _get_id 方法并返回其结果

    def get_related_ddl(self, question: str, **kwargs) -> list:  # 定义一个方法 get_related_ddl,接受一个字符串参数 question 和其他可选参数
        return []  # 返回一个空列表

    def get_related_documentation(self, question: str, **kwargs) -> list:  # 定义一个方法 get_related_documentation,接受一个字符串参数 question 和其他可选参数
        return []  # 返回一个空列表

    def get_similar_question_sql(self, question: str, **kwargs) -> list:  # 定义一个方法 get_similar_question_sql,接受一个字符串参数 question 和其他可选参数
        return []  # 返回一个空列表

    def get_training_data(self, **kwargs) -> pd.DataFrame:  # 定义一个方法 get_training_data,接受其他可选参数,返回一个 pandas DataFrame
        # 返回一个包含训练数据的 DataFrame
        return pd.DataFrame({'id': {0: '19546-ddl',  # 训练数据 ID
          1: '91597-sql',
          2: '133976-sql',
          3: '59851-doc',
          4: '73046-sql'},
         'training_data_type': {0: 'ddl',  # 训练数据类型
          1: 'sql',
          2: 'sql',
          3: 'documentation',
          4: 'sql'},
         'question': {0: None,  # 问题
          1: 'What are the top selling genres?',
          2: 'What are the low 7 artists by sales?',
          3: None,
          4: 'What is the total sales for each customer?'},
         'content': {0: 'CREATE TABLE [Invoice]\n(\n    [InvoiceId] INTEGER  NOT NULL,\n    [CustomerId] INTEGER  NOT NULL,\n    [InvoiceDate] DATETIME  NOT NULL,\n    [BillingAddress] NVARCHAR(70),\n    [BillingCity] NVARCHAR(40),\n    [BillingState] NVARCHAR(40),\n    [BillingCountry] NVARCHAR(40),\n    [BillingPostalCode] NVARCHAR(10),\n    [Total] NUMERIC(10,2)  NOT NULL,\n    CONSTRAINT [PK_Invoice] PRIMARY KEY  ([InvoiceId]),\n    FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',  # 内容
          1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
          2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
          3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
          4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})

    def remove_training_data(id: str, **kwargs) -> bool:  # 定义一个方法 remove_training_data,接受一个字符串参数 id 和其他可选参数
        return True  # 返回 True,表示删除成功
代码功能和作用
  1. 导入模块

    • pandas:用于数据处理和操作。
    • VannaBase:从上一级目录的 base 模块导入 VannaBase 类,作为基类。
  2. 定义 MockVectorDB

    • 继承自 VannaBase,模拟一个向量数据库的基本操作。
  3. 构造函数 __init__

    • 定义类的构造函数,目前不做任何实际操作,仅包含一个 pass 语句。
  4. _get_id 方法

    • 私有方法 _get_id,接受一个字符串参数 value,返回该字符串的哈希值作为 ID。
  5. add_ddl 方法

    • 接受一个 DDL 语句字符串,调用 _get_id 方法返回其哈希值作为 ID。
  6. add_documentation 方法

    • 接受一个文档字符串,调用 _get_id 方法返回其哈希值作为 ID。
  7. add_question_sql 方法

    • 接受一个问题和对应的 SQL 语句,调用 _get_id 方法返回问题的哈希值作为 ID。
  8. get_related_ddl 方法

    • 接受一个问题字符串,返回一个空列表,表示没有相关的 DDL。
  9. get_related_documentation 方法

    • 接受一个问题字符串,返回一个空列表,表示没有相关的文档。
  10. get_similar_question_sql 方法

    • 接受一个问题字符串,返回一个空列表,表示没有类似的问题和 SQL。
  11. get_training_data 方法

    • 返回一个包含训练数据的 pandas DataFrame,其中包含 ID、训练数据类型、问题和内容。
  12. remove_training_data 方法

    • 接受一个数据 ID,返回 True,表示成功删除训练数据。

openai

openai_chat.py

代码用途

OpenAI_Chat 类实现了与 OpenAI 模型的基本接口,提供了初始化、消息处理和提示提交等功能。该类用于与 OpenAI 模型进行通信,发送提示并接收响应,同时处理和返回响应中的文本内容。

  • 与 OpenAI 模型的通信OpenAI_Chat 类实现了与 OpenAI 模型的接口,通过 HTTP 请求与 OpenAI 模型进行通信,发送提示并接收响应。
  • 消息处理:提供系统消息、用户消息和助手消息的方法,用于生成与 OpenAI 模型通信的消息格式。
  • 提示提交和响应处理:检查和验证提示内容,计算 token 数量,并调用 OpenAI API 生成响应,同时处理和返回响应中的文本内容。
import os  # 导入 os 模块,用于与操作系统交互

from openai import OpenAI  # 从 openai 模块导入 OpenAI 类

from ..base import VannaBase  # 从上一级目录的 base 模块导入 VannaBase 类


class OpenAI_Chat(VannaBase):  # 定义一个名为 OpenAI_Chat 的类,继承自 VannaBase
    def __init__(self, client=None, config=None):  # 定义类的构造函数,接受可选参数 client 和 config
        VannaBase.__init__(self, config=config)  # 调用父类 VannaBase 的构造函数

        # 设置默认参数 - 可以通过 config 覆盖
        self.temperature = 0.7
        self.max_tokens = 500

        if "temperature" in config:
            self.temperature = config["temperature"]  # 如果 config 中包含 "temperature",则覆盖默认温度值

        if "max_tokens" in config:
            self.max_tokens = config["max_tokens"]  # 如果 config 中包含 "max_tokens",则覆盖默认最大 token 数

        if "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )  # 如果 config 中包含 "api_type",抛出异常

        if "api_base" in config:
            raise Exception(
                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
            )  # 如果 config 中包含 "api_base",抛出异常

        if "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )  # 如果 config 中包含 "api_version",抛出异常

        if client is not None:
            self.client = client  # 如果提供了 client 参数,则将其设置为实例属性
            return

        if config is None and client is None:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # 如果没有提供 config 和 client,从环境变量中获取 API 密钥并创建 OpenAI 客户端
            return

        if "api_key" in config:
            self.client = OpenAI(api_key=config["api_key"])  # 如果 config 中包含 "api_key",使用该密钥创建 OpenAI 客户端

    def system_message(self, message: str) -> any:  # 定义一个方法 system_message,接受一个字符串参数 message
        return {"role": "system", "content": message}  # 返回一个包含角色和内容的字典

    def user_message(self, message: str) -> any:  # 定义一个方法 user_message,接受一个字符串参数 message
        return {"role": "user", "content": message}  # 返回一个包含角色和内容的字典

    def assistant_message(self, message: str) -> any:  # 定义一个方法 assistant_message,接受一个字符串参数 message
        return {"role": "assistant", "content": message}  # 返回一个包含角色和内容的字典

    def submit_prompt(self, prompt, **kwargs) -> str:  # 定义一个方法 submit_prompt,接受一个参数 prompt 和其他可选参数
        if prompt is None:
            raise Exception("Prompt is None")  # 如果 prompt 为 None,抛出异常

        if len(prompt) == 0:
            raise Exception("Prompt is empty")  # 如果 prompt 为空,抛出异常

        # 计算消息日志中的 token 数量
        # 使用 4 作为每个 token 的近似字符数
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)  # 如果 kwargs 中提供了 model 参数,则使用该模型
            print(
                f"Using model {model} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)  # 如果 kwargs 中提供了 engine 参数,则使用该引擎
            print(
                f"Using model {engine} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            if num_tokens > 3500:
                model = "gpt-3.5-turbo-16k"  # 如果 token 数量超过 3500,使用 gpt-3.5-turbo-16k 模型
            else:
                model = "gpt-3.5-turbo"  # 否则使用 gpt-3.5-turbo 模型

            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        # 查找包含文本的第一个响应(有些响应可能没有文本)
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        # 如果没有找到包含文本的响应,返回第一个响应的内容(可能为空)
        return response.choices[0].message.content
功能和作用
  1. 导入模块

    • os:用于与操作系统交互,特别是获取环境变量。
    • OpenAI:用于与 OpenAI API 进行交互。
    • VannaBaseDependencyError:用于继承基类和处理依赖错误。
  2. 定义 OpenAI_Chat

    • 继承自 VannaBase,实现了与 OpenAI 模型的接口。
  3. 构造函数 __init__

    • 调用父类的构造函数并初始化一些默认参数。
    • 根据配置覆盖默认参数。
    • 检查并设置 OpenAI 客户端。
  4. 消息方法

    • system_messageuser_messageassistant_message 方法分别返回带有角色和内容的字典,用于与 OpenAI 模型的通信。
  5. submit_prompt 方法

    • 检查和验证提示内容。
    • 计算消息日志中的 token 数量。
    • 根据不同的配置和参数,调用 OpenAI API 生成响应。
    • 查找并返回包含文本的第一个响应。

openai_embeddings.py

这个类的主要功能是初始化 OpenAI 客户端并根据提供的配置生成嵌入向量。

from openai import OpenAI  # 从 openai 库导入 OpenAI 类,用于与 OpenAI API 交互

from ..base import VannaBase  # 从本地模块导入 VannaBase 类,这是一个基础类

# 定义一个新的类 OpenAI_Embeddings,它继承自 VannaBase
class OpenAI_Embeddings(VannaBase):
    def __init__(self, client=None, config=None):
        # 调用父类 VannaBase 的构造函数,并传递配置参数
        VannaBase.__init__(self, config=config)

        # 如果提供了 client 参数,则使用这个参数初始化 self.client
        if client is not None:
            self.client = client
            return  # 结束构造函数的执行

        # 如果 self.client 已经存在(通过父类初始化),则结束构造函数
        if self.client is not None:
            return

        # 如果没有提供 client 参数且 self.client 尚未初始化,则创建一个新的 OpenAI 客户端
        self.client = OpenAI()

        # 如果没有提供配置参数,则结束构造函数
        if config is None:
            return

        # 根据配置参数设置 OpenAI 客户端的不同属性
        if "api_type" in config:
            self.client.api_type = config["api_type"]  # 设置 API 类型

        if "api_base" in config:
            self.client.api_base = config["api_base"]  # 设置 API 基础 URL

        if "api_version" in config:
            self.client.api_version = config["api_version"]  # 设置 API 版本

        if "api_key" in config:
            self.client.api_key = config["api_key"]  # 设置 API 密钥

    # 定义生成嵌入的方法,接受输入数据 data 和其他可选参数 kwargs
    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        # 如果配置参数中包含引擎配置,则使用该引擎生成嵌入
        if self.config is not None and "engine" in self.config:
            embedding = self.client.embeddings.create(
                engine=self.config["engine"],  # 使用指定的引擎
                input=data,  # 输入数据
            )
        else:
            # 否则使用默认的模型生成嵌入
            embedding = self.client.embeddings.create(
                model="text-embedding-ada-002",  # 使用默认模型
                input=data,  # 输入数据
            )

        # 返回生成的嵌入向量
        return embedding.get("data")[0]["embedding"]

opensearch

opensearch_vector.py

OpenSearch_VectorStore 类用于与 OpenSearch 进行交互,提供存储和检索文档、DDL(数据定义语言)和问题-SQL 对的方法。

import base64  # 导入 base64 模块,用于处理 base64 编码
import uuid  # 导入 uuid 模块,用于生成唯一标识符
from typing import List  # 导入 List 类型提示

import pandas as pd  # 导入 pandas 模块,用于数据处理
from opensearchpy import OpenSearch  # 导入 OpenSearch 模块,用于与 OpenSearch 服务交互

from ..base import VannaBase  # 从本地模块导入 VannaBase 类,这是一个基础类

# 定义 OpenSearch_VectorStore 类,继承自 VannaBase
class OpenSearch_VectorStore(VannaBase):
    def __init__(self, config=None):
        # 调用父类 VannaBase 的构造函数,并传递配置参数
        VannaBase.__init__(self, config=config)

        # 初始化索引名称
        document_index = "vanna_document_index"
        ddl_index = "vanna_ddl_index"
        question_sql_index = "vanna_questions_sql_index"
        
        # 从配置中获取自定义索引名称
        if config is not None and "es_document_index" in config:
            document_index = config["es_document_index"]
        if config is not None and "es_ddl_index" in config:
            ddl_index = config["es_ddl_index"]
        if config is not None and "es_question_sql_index" in config:
            question_sql_index = config["es_question_sql_index"]

        # 将索引名称保存为类的属性
        self.document_index = document_index
        self.ddl_index = ddl_index
        self.question_sql_index = question_sql_index
        print("OpenSearch_VectorStore initialized with document_index: ",
              document_index, " ddl_index: ", ddl_index, " question_sql_index: ",
              question_sql_index)

        # 定义默认索引设置
        document_index_settings = {
            "settings": {
                "index": {
                    "number_of_shards": 6,
                    "number_of_replicas": 2
                }
            },
            "mappings": {
                "properties": {
                    "question": {"type": "text"},
                    "doc": {"type": "text"}
                }
            }
        }

        ddl_index_settings = {
            "settings": {
                "index": {
                    "number_of_shards": 6,
                    "number_of_replicas": 2
                }
            },
            "mappings": {
                "properties": {
                    "ddl": {"type": "text"},
                    "doc": {"type": "text"}
                }
            }
        }

        question_sql_index_settings = {
            "settings": {
                "index": {
                    "number_of_shards": 6,
                    "number_of_replicas": 2
                }
            },
            "mappings": {
                "properties": {
                    "question": {"type": "text"},
                    "sql": {"type": "text"}
                }
            }
        }

        # 从配置中获取自定义索引设置
        if config is not None and "es_document_index_settings" in config:
            document_index_settings = config["es_document_index_settings"]
        if config is not None and "es_ddl_index_settings" in config:
            ddl_index_settings = config["es_ddl_index_settings"]
        if config is not None and "es_question_sql_index_settings" in config:
            question_sql_index_settings = config["es_question_sql_index_settings"]

        # 将索引设置保存为类的属性
        self.document_index_settings = document_index_settings
        self.ddl_index_settings = ddl_index_settings
        self.question_sql_index_settings = question_sql_index_settings

        # 初始化 OpenSearch 客户端
        es_urls = None
        if config is not None and "es_urls" in config:
            es_urls = config["es_urls"]

        # 获取主机和端口配置
        host = config["es_host"] if config and "es_host" in config else "localhost"
        port = config["es_port"] if config and "es_port" in config else 9200
        ssl = config["es_ssl"] if config and "es_ssl" in config else False
        verify_certs = config["es_verify_certs"] if config and "es_verify_certs" in config else False

        # 获取认证配置
        auth = (config["es_user"], config["es_password"]) if config and "es_user" in config else None

        # 基于 base64 的认证
        headers = None
        if config and "es_encoded_base64" in config and "es_user" in config and "es_password" in config:
            if config["es_encoded_base64"]:
                encoded_credentials = base64.b64encode(
                    (config["es_user"] + ":" + config["es_password"]).encode("utf-8")
                ).decode("utf-8")
                headers = {'Authorization': 'Basic ' + encoded_credentials}
                auth = None

        # 自定义 headers
        if config and "es_headers" in config:
            headers = config["es_headers"]

        # 获取超时和重试配置
        timeout = config["es_timeout"] if config and "es_timeout" in config else 60
        max_retries = config["es_max_retries"] if config and "es_max_retries" in config else 10
        es_http_compress = config["es_http_compress"] if config and "es_http_compress" in config else False

        print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
              " host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
              verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)

        # 初始化 OpenSearch 客户端
        if es_urls is not None:
            self.client = OpenSearch(
                hosts=[es_urls],
                http_compress=es_http_compress,
                use_ssl=ssl,
                verify_certs=verify_certs,
                timeout=timeout,
                max_retries=max_retries,
                retry_on_timeout=True,
                http_auth=auth,
                headers=headers
            )
        else:
            self.client = OpenSearch(
                hosts=[{'host': host, 'port': port}],
                http_compress=es_http_compress,
                use_ssl=ssl,
                verify_certs=verify_certs,
                timeout=timeout,
                max_retries=max_retries,
                retry_on_timeout=True,
                http_auth=auth,
                headers=headers
            )

        print("OpenSearch_VectorStore initialized with client over ")

        # 执行一个简单的查询来检查连接
        try:
            print('Connected to OpenSearch cluster:')
            info = self.client.info()
            print('OpenSearch cluster info:', info)
        except Exception as e:
            print('Error connecting to OpenSearch cluster:', e)

        # 如果索引不存在,则创建索引
        self.create_index_if_not_exists(self.document_index, self.document_index_settings)
        self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings)
        self.create_index_if_not_exists(self.question_sql_index, self.question_sql_index_settings)

    # 创建索引的方法
    def create_index(self):
        for index in [self.document_index, self.ddl_index, self.question_sql_index]:
            try:
                self.client.indices.create(index)
            except Exception as e:
                print("Error creating index: ", e)
                print(f"opensearch index {index} already exists")
                pass

    # 如果索引不存在,则创建索引的方法
    def create_index_if_not_exists(self, index_name: str, index_settings: dict) -> bool:
        try:
            if not self.client.indices.exists(index_name):
                print(f"Index {index_name} does not exist. Creating...")
                self.client.indices.create(index=index_name, body=index_settings)
                return True
            else:
                print(f"Index {index_name} already exists.")
                return False
        except Exception as e:
            print(f"Error creating index: {index_name} ", e)
            return False

    # 添加 DDL 文档的方法
    def add_ddl(self, ddl: str, **kwargs) -> str:
        id = str(uuid.uuid4()) + "-ddl"  # 生成唯一标识符
        ddl_dict = {"ddl": ddl}
        response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id, **kwargs)
        return response['_id']

    # 添加文档的方法
    def add_documentation(self, doc: str, **kwargs) -> str:
        id = str(uuid.uuid4()) + "-doc"  # 生成唯一标识符
        doc_dict = {"doc": doc}
        response = self.client.index(index=self.document_index, id=id, body=doc_dict, **kwargs)
        return response['_id']

    # 添加问题和 SQL 的方法
    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        id = str(uuid.uuid4()) + "-sql"  # 生成唯一标识符
        question_sql_dict = {"question": question, "sql": sql}
        response = self.client.index(index=self.question_sql_index, body=question_sql_dict, id=id, **kwargs)
        return response['_id']

    # 获取相关 DDL 文档的方法
    def get_related_ddl(self, question: str, **kwargs) -> List[str]:
        query = {"query": {"match": {"ddl": question}}}
        print(query)
        response = self.client.search(index=self.ddl_index, body=query, **kwargs)
        return [hit['_source']['ddl'] for hit in response['hits']['hits']]

    # 获取相关文档的方法
    def get_related_documentation(self, question: str, **kwargs) -> List[str]:
        query = {"query": {"match": {"doc": question}}}
        print(query)
        response = self.client.search(index=self.document_index, body=query, **kwargs)
        return [hit['_source']['doc'] for hit in response['hits']['hits']]

    # 获取相似问题和 SQL 的方法
    def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
        query = {"query": {"match": {"question": question}}}
        print(query)
        response = self.client.search(index=self.question_sql_index, body=query, **kwargs)
        return [(hit['_source']['question'], hit['_source']['sql']) for hit in response['hits']['hits']]

    # 获取训练数据的方法
    def get_training_data(self, **kwargs) -> pd.DataFrame:
        data = []
        
        # 从文档索引中获取数据
        response = self.client.search(index=self.document_index, body={"query": {"match_all": {}}}, size=1000)
        for hit in response['hits']['hits']:
            data.append({"id": hit["_id"], "training_data_type": "documentation", "question": "", "content": hit["_source"]['doc']})
        
        # 从问题和 SQL 索引中获取数据
        response = self.client.search(index=self.question_sql_index, body={"query": {"match_all": {}}}, size=1000)
        for hit in response['hits']['hits']:
            data.append({"id": hit["_id"], "training_data_type": "sql", "question": hit.get("_source", {}).get("question", ""), "content": hit.get("_source", {}).get("sql", "")})
        
        # 从 DDL 索引中获取数据
        response = self.client.search(index=self.ddl_index, body={"query": {"match_all": {}}}, size=1000)
        for hit in response['hits']['hits']:
            data.append({"id": hit["_id"], "training_data_type": "ddl", "question": "", "content": hit["_source"]['ddl']})
        
        # 返回包含所有数据的 pandas DataFrame
        return pd.DataFrame(data)

    # 删除训练数据的方法
    def remove_training_data(self, id: str, **kwargs) -> bool:
        try:
            if id.endswith("-sql"):
                self.client.delete(index=self.question_sql_index, id=id)
                return True
            elif id.endswith("-ddl"):
                self.client.delete(index=self.ddl_index, id=id, **kwargs)
                return True
            elif id.endswith("-doc"):
                self.client.delete(index=self.document_index, id=id, **kwargs)
                return True
            else:
                return False
        except Exception as e:
            print("Error deleting training data: ", e)
            return False

    # 生成嵌入的方法(空方法)
    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        pass  # OpenSearch 不需要生成嵌入

# 示例初始化调用
# OpenSearch_VectorStore.__init__(self, config={'es_urls': "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True})

# OpenSearch_VectorStore.__init__(self, config={'es_host': "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True})
功能和作用

1.初始化:

  • 配置索引名称和设置。

  • 初始化 OpenSearch 客户端,支持多种配置选项(如认证、超时、重试等)。

  • 创建必要的索引。
    2.索引管理:

  • create_index_if_not_exists:检查索引是否存在,不存在则创建。

  • create_index:创建指定的索引。
    3.文档操作:

  • add_ddl、add_documentation、add_question_sql:添加 DDL 文档、普通文档和问题-SQL 对。

  • get_related_ddl、get_related_documentation、get_similar_question_sql:检索相关的 DDL 文档、普通文档和问题-SQL 对。

  • get_training_data:获取所有训练数据,返回 pandas DataFrame。

  • remove_training_data:删除指定的训练数据。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值