前言
vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析,该部分内容为接入各类ai方法。
一、vanna源码分析
pinecone_vector
PineconeDB_VectorStore 类用于使用 Pinecone 向量数据库存储和检索向量数据。这是通过与 Pinecone 客户端交互来实现的。
我们将 PineconeDB_VectorStore
代码分段进行详细注释和解析。
导入和类定义部分
import json # 导入 JSON 模块,用于处理 JSON 数据
from typing import List # 导入 List 类型提示
from pinecone import Pinecone, PodSpec, ServerlessSpec # 从 pinecone 库导入 Pinecone 客户端和规格类
import pandas as pd # 导入 pandas 模块,用于数据处理
from ..base import VannaBase # 从本地模块导入 VannaBase 基础类
from ..utils import deterministic_uuid # 从本地模块导入生成确定性 UUID 的函数
from fastembed import TextEmbedding # 从 fastembed 模块导入 TextEmbedding 类
- 导入了所需的模块和库,包括 JSON 处理、类型提示、Pinecone 客户端和规格、数据处理工具 pandas,以及一些本地模块。
class PineconeDB_VectorStore(VannaBase):
"""
使用 PineconeDB 的向量存储类
Args:
config (dict): 配置字典。必须提供 Pinecone 客户端或 API 密钥。
Raises:
ValueError: 如果配置无效,抛出错误。
"""
- 定义
PineconeDB_VectorStore
类,继承自VannaBase
,并在类文档中说明需要传递的配置参数和可能的异常。
初始化方法
def __init__(self, config=None):
# 调用父类 VannaBase 的构造函数
VannaBase.__init__(self, config=config)
if config is None:
raise ValueError("需要提供配置,传递 Pinecone 客户端或 API 密钥。")
client = config.get("client") # 获取配置中的客户端
api_key = config.get("api_key") # 获取配置中的 API 密钥
if not api_key and not client:
raise ValueError("需要在配置中提供 api_key 或传递已配置的客户端")
if not client and api_key:
self._client = Pinecone(api_key=api_key) # 使用 API 密钥初始化 Pinecone 客户端
elif not isinstance(client, Pinecone):
raise ValueError("client 必须是 Pinecone 的实例")
else:
self._client = client # 使用传递的 Pinecone 客户端
# 设置其他配置参数
self.n_results = config.get("n_results", 10)
self.dimensions = config.get("dimensions", 384)
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
self.documentation_namespace = config.get("documentation_namespace", "documentation")
self.distance_metric = config.get("distance_metric", "cosine")
self.ddl_namespace = config.get("ddl_namespace", "ddl")
self.sql_namespace = config.get("sql_namespace", "sql")
self.index_name = config.get("index_name", "vanna-index")
self.metadata_config = config.get("metadata_config", {})
self.server_type = config.get("server_type", "serverless")
if self.server_type not in ["serverless", "pod"]:
raise ValueError("server_type 必须是 'serverless' 或 'pod'")
# 配置 Pod 规格
self.podspec = config.get(
"podspec",
PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config),
)
# 配置无服务器规格
self.serverless_spec = config.get("serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2"))
self._setup_index() # 设置索引
- 初始化方法中,调用了父类的构造函数。
- 检查配置是否为空,如果为空,则抛出
ValueError
。 - 根据配置初始化 Pinecone 客户端,确保传递的客户端是 Pinecone 的实例。
- 设置其他相关配置参数,如返回结果数、向量维度、模型名称、命名空间、距离度量等。
- 检查
server_type
是否为有效值(serverless
或pod
)。 - 配置 Pod 和无服务器规格,并调用
_setup_index
方法进行索引设置。
索引设置和检查方法
def _set_index_host(self, host: str) -> None:
self.Index = self._client.Index(host=host) # 设置索引主机
def _setup_index(self) -> None:
existing_indexes = self._get_indexes() # 获取现有索引
if self.index_name not in existing_indexes and self.server_type == "serverless":
self._client.create_index(
name=self.index_name,
dimension=self.dimensions,
metric=self.distance_metric,
spec=self.serverless_spec,
)
pinecone_index_host = self._client.describe_index(self.index_name)["host"]
self._set_index_host(pinecone_index_host)
elif self.index_name not in existing_indexes and self.server_type == "pod":
self._client.create_index(
name=self.index_name,
dimension=self.dimensions,
metric=self.distance_metric,
spec=self.podspec,
)
pinecone_index_host = self._client.describe_index(self.index_name)["host"]
self._set_index_host(pinecone_index_host)
else:
pinecone_index_host = self._client.describe_index(self.index_name)["host"]
self._set_index_host(pinecone_index_host)
def _get_indexes(self) -> list:
return [index["name"] for index in self._client.list_indexes()] # 列出所有索引的名称
def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
fetch_response = self.Index.fetch(ids=[id], namespace=namespace) # 从指定命名空间获取嵌入
return fetch_response["vectors"] != {} # 如果嵌入存在,返回 True
_set_index_host
方法设置索引主机。_setup_index
方法根据配置创建或获取索引,如果索引不存在则创建索引,并设置索引主机。_get_indexes
方法列出所有现有索引的名称。_check_if_embedding_exists
方法检查嵌入是否已经存在于指定的命名空间中。
添加和查询数据的方法
def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl) + "-ddl" # 生成唯一标识符
if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace):
print(f"DDL with id: {id} 已存在于索引中,跳过...")
return id
self.Index.upsert(
vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})],
namespace=self.ddl_namespace,
)
return id
def add_documentation(self, doc: str, **kwargs) -> str:
id = deterministic_uuid(doc) + "-doc" # 生成唯一标识符
if self._check_if_embedding_exists(id=id, namespace=self.documentation_namespace):
print(f"Documentation with id: {id} 已存在于索引中,跳过...")
return id
self.Index.upsert(
vectors=[(id, self.generate_embedding(doc), {"documentation": doc})],
namespace=self.documentation_namespace,
)
return id
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
id = deterministic_uuid(question_sql_json) + "-sql" # 生成唯一标识符
if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace):
print(f"Question-SQL with id: {id} 已存在于索引中,跳过...")
return id
self.Index.upsert(
vectors=[(id, self.generate_embedding(question_sql_json), {"sql": question_sql_json})],
namespace=self.sql_namespace,
)
return id
add_ddl
方法添加 DDL 数据到指定命名空间,如果数据已经存在则跳过。add_documentation
方法添加文档数据到指定命名空间,如果数据已经存在则跳过。add_question_sql
方法添加问题和 SQL 数据到指定命名空间,如果数据已经存在则跳过。
查询相关数据的方法
def get_related_ddl(self, question: str, **kwargs) -> list:
res = self.Index.query(
namespace=self.ddl_namespace, # 使用DDL命名空间
vector=self.generate_embedding(question), # 生成问题的嵌入向量
top_k=self.n_results, # 返回的结果数量
include_values=True,
include_metadata=True,
)
# 如果查询有结果,返回匹配的DDL语句列表,否则返回空列表
return [match["metadata"]["ddl"] for match in res["matches"]] if res else []
获取相关文档方法
def get_related_documentation(self, question: str, **kwargs) -> list:
res = self.Index.query(
namespace=self.documentation_namespace, # 使用文档命名空间
vector=self.generate_embedding(question), # 生成问题的嵌入向量
top_k=self.n_results, # 返回的结果数量
include_values=True,
include_metadata=True,
)
# 如果查询有结果,返回匹配的文档列表,否则返回空列表
return (
[match["metadata"]["documentation"] for match in res["matches"]]
if res
else []
)
获取相似问题和SQL方法
def get_similar_question_sql(self, question: str, **kwargs) -> list:
res = self.Index.query(
namespace=self.sql_namespace, # 使用SQL命名空间
vector=self.generate_embedding(question), # 生成问题的嵌入向量
top_k=self.n_results, # 返回的结果数量
include_values=True,
include_metadata=True,
)
# 如果查询有结果,解析并返回匹配的Question-SQL对列表,否则返回空列表
return (
[
{
key: value
for key, value in json.loads(match["metadata"]["sql"]).items()
}
for match in res["matches"]
]
if res
else []
)
获取训练数据方法
def get_training_data(self, **kwargs) -> pd.DataFrame:
df = pd.DataFrame()
namespaces = {
"sql": self.sql_namespace,
"ddl": self.ddl_namespace,
"documentation": self.documentation_namespace,
}
for data_type, namespace in namespaces.items():
data = self.Index.query(
top_k=10000, # Pinecone允许的最大结果数量
namespace=namespace, # 当前命名空间
include_values=True,
include_metadata=True,
vector=[0.0] * self.dimensions, # 使用零向量查询
)
if data is not None:
id_list = [match["id"] for match in data["matches"]]
content_list = [
match["metadata"][data_type] for match in data["matches"]
]
question_list = [
(
json.loads(match["metadata"][data_type])["question"]
if data_type == "sql"
else None
)
for match in data["matches"]
]
df_data = pd.DataFrame(
{
"id": id_list,
"question": question_list,
"content": content_list,
}
)
df_data["training_data_type"] = data_type
df = pd.concat([df, df_data])
# 返回包含所有训练数据的DataFrame
return df
移除训练数据方法
def remove_training_data(self, id: str, **kwargs) -> bool:
if id.endswith("-sql"):
self.Index.delete(ids=[id], namespace=self.sql_namespace)
return True
elif id.endswith("-ddl"):
self.Index.delete(ids=[id], namespace=self.ddl_namespace)
return True
elif id.endswith("-doc"):
self.Index.delete(ids=[id], namespace=self.documentation_namespace)
return True
else:
return False
# 根据ID的后缀判断并从相应的命名空间中删除数据
生成嵌入方法
def generate_embedding(self, data: str, **kwargs) -> List[float]:
embedding_model = TextEmbedding(model_name=self.fastembed_model)
embedding = next(embedding_model.embed(data))
return embedding.tolist()
# 使用指定的Fastembed模型生成数据的嵌入向量,并返回生成的嵌入向量列表
qdrant
这个类实现了一个基于Qdrant的向量存储,用于存储和检索向量化的数据,如DDL(数据定义语言)、文档和SQL查询。通过该类,用户可以将文本数据转换为向量并存储在Qdrant中,并根据相似性进行检索。
初始化方法
def __init__(self, config={}):
VannaBase.__init__(self, config=config) # 调用父类的初始化方法
client = config.get("client") # 从配置中获取客户端
if client is None:
self._client = QdrantClient( # 如果没有提供客户端,则创建一个新的QdrantClient实例
location=config.get("location", None),
url=config.get("url", None),
prefer_grpc=config.get("prefer_grpc", False),
https=config.get("https", None),
api_key=config.get("api_key", None),
timeout=config.get("timeout", None),
path=config.get("path", None),
prefix=config.get("prefix", None),
)
elif not isinstance(client, QdrantClient):
raise TypeError( # 如果提供的客户端不是QdrantClient实例,则抛出类型错误
f"Unsupported client of type {client.__class__} was set in config"
)
else:
self._client = client # 使用提供的QdrantClient实例
self.n_results = config.get("n_results", 10) # 设置返回结果数量,默认为10
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5") # 设置嵌入模型
self.collection_params = config.get("collection_params", {}) # 设置集合参数
self.distance_metric = config.get("distance_metric", models.Distance.COSINE) # 设置距离度量,默认为余弦距离
self.documentation_collection_name = config.get(
"documentation_collection_name", "documentation"
)
self.ddl_collection_name = config.get(
"ddl_collection_name", "ddl"
)
self.sql_collection_name = config.get(
"sql_collection_name", "sql"
)
self.id_suffixes = {
self.ddl_collection_name: "ddl",
self.documentation_collection_name: "doc",
self.sql_collection_name: "sql",
}
self._setup_collections() # 调用方法设置集合
- 初始化
Qdrant_VectorStore
类实例。 - 根据配置参数初始化 Qdrant 客户端。
- 设置返回结果数量、嵌入模型、集合参数和距离度量。
- 定义集合名称和 ID 后缀。
- 调用
_setup_collections
方法设置集合。
添加问题-SQL对
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql) # 组合问题和SQL对
id = deterministic_uuid(question_answer) # 生成唯一ID
self._client.upsert(
self.sql_collection_name,
points=[
models.PointStruct(
id=id,
vector=self.generate_embedding(question_answer), # 生成嵌入向量
payload={
"question": question,
"sql": sql,
},
)
],
)
return self._format_point_id(id, self.sql_collection_name) # 返回格式化的点ID
- 添加问题和SQL对到SQL集合。
- 生成问题和SQL对的嵌入并插入到集合中。
- 返回格式化的点ID。
添加DDL
def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl) # 生成唯一ID
self._client.upsert(
self.ddl_collection_name,
points=[
models.PointStruct(
id=id,
vector=self.generate_embedding(ddl), # 生成嵌入向量
payload={
"ddl": ddl,
},
)
],
)
return self._format_point_id(id, self.ddl_collection_name) # 返回格式化的点ID
- 添加DDL到DDL集合。
- 生成DDL的嵌入并插入到集合中。
- 返回格式化的点ID。
添加文档
def add_documentation(self, documentation: str, **kwargs) -> str:
id = deterministic_uuid(documentation) # 生成唯一ID
self._client.upsert(
self.documentation_collection_name,
points=[
models.PointStruct(
id=id,
vector=self.generate_embedding(documentation), # 生成嵌入向量
payload={
"documentation": documentation,
},
)
],
)
return self._format_point_id(id, self.documentation_collection_name) # 返回格式化的点ID
- 添加文档到文档集合。
- 生成文档的嵌入并插入到集合中。
- 返回格式化的点ID。
获取训练数据
def get_training_data(self, **kwargs) -> pd.DataFrame:
df = pd.DataFrame() # 初始化空的DataFrame
if sql_data := self._get_all_points(self.sql_collection_name): # 获取SQL集合中的所有点
question_list = [data.payload["question"] for data in sql_data]
sql_list = [data.payload["sql"] for data in sql_data]
id_list = [
self._format_point_id(data.id, self.sql_collection_name)
for data in sql_data
]
df_sql = pd.DataFrame(
{
"id": id_list,
"question": question_list,
"content": sql_list,
}
)
df_sql["training_data_type"] = "sql"
df = pd.concat([df, df_sql]) # 合并到主DataFrame
if ddl_data := self._get_all_points(self.ddl_collection_name): # 获取DDL集合中的所有点
ddl_list = [data.payload["ddl"] for data in ddl_data]
id_list = [
self._format_point_id(data.id, self.ddl_collection_name)
for data in ddl_data
]
df_ddl = pd.DataFrame(
{
"id": id_list,
"question": [None for _ in ddl_list],
"content": ddl_list,
}
)
df_ddl["training_data_type"] = "ddl"
df = pd.concat([df, df_ddl]) # 合并到主DataFrame
if doc_data := self._get_all_points(self.documentation_collection_name): # 获取文档集合中的所有点
document_list = [data.payload["documentation"] for data in doc_data]
id_list = [
self._format_point_id(data.id, self.documentation_collection_name)
for data in doc_data
]
df_doc = pd.DataFrame(
{
"id": id_list,
"question": [None for _ in document_list],
"content": document_list,
}
)
df_doc["training_data_type"] = "documentation"
df = pd.concat([df, df_doc]) # 合并到主DataFrame
return df # 返回包含所有训练数据的DataFrame
- 获取所有训练数据并返回一个包含所有数据的DataFrame。
- 分别从SQL集合、DDL集合和文档集合中获取数据。
- 将数据组合成DataFrame格式。
移除训练数据
def remove_training_data(self, id: str, **kwargs) -> bool:
try:
id, collection_name = self._parse_point_id(id) # 解析点ID获取集合名称
res = self._client.delete(collection_name, points_selector=[id])
return True
except ValueError:
return False
- 根据ID移除训练数据。
- 解析点ID获取集合名称。
- 从指定集合中删除数据。
移除集合
def remove_collection(self, collection_name: str) -> bool:
if collection_name in self.id_suffixes.keys():
self._client.delete_collection(collection_name) # 删除集合
self._setup_collections() # 重新设置集合
return True
else:
return False
- 移除指定的集合并重置集合为空状态。
- 根据集合名称删除集合并重新设置集合。
生成嵌入
def generate_embedding(self, data: str, **kwargs) -> List[float]:
embedding_model = self._client._get_or_init_model(
model_name=self.fastembed_model
)
embedding = next(embedding_model.embed(data))
return embedding.tolist() # 返回生成的嵌入向量
- 生成数据的嵌入向量。
- 使用指定的Fastembed模型生成数据的嵌入向量并返回。
获取相似问题和SQL方法
def get_similar_question_sql(self, question: str, **kwargs) -> list:
results = self._client.search(
self.sql_collection_name, # 使用SQL集合名称
query_vector=self.generate_embedding(question), # 生成问题的嵌入向量
limit=self.n_results, # 返回的结果数量限制
with_payload=True,
)
return [dict(result.payload) for result in results] # 返回结果列表,其中包含查询的Payload
获取相关DDL方法
def get_related_ddl(self, question: str, **kwargs) -> list:
results = self._client.search(
self.ddl_collection_name, # 使用DDL集合名称
query_vector=self.generate_embedding(question), # 生成问题的嵌入向量
limit=self.n_results, # 返回的结果数量限制
with_payload=True,
)
# 返回结果列表,其中包含DDL字段
return [result.payload["ddl"] for result in results]
- 根据输入的问题,在DDL集合中查询相关的DDL,并返回结果。
获取相关文档方法
def get_related_documentation(self, question: str, **kwargs) -> list:
results = self._client.search(
self.documentation_collection_name, # 使用文档集合名称
query_vector=self.generate_embedding(question), # 生成问题的嵌入向量
limit=self.n_results, # 返回的结果数量限制
with_payload=True,
)
# 返回结果列表,其中包含文档字段
return [result.payload["documentation"] for result in results]
- 根据输入的问题,在文档集合中查询相关的文档,并返回结果。
嵌入维度
@cached_property
def embeddings_dimension(self):
return len(self.generate_embedding("ABCDEF")) # 返回嵌入向量的维度
- 返回嵌入向量的维度。
获取所有点的方法
def _get_all_points(self, collection_name: str):
results: List[models.Record] = []
next_offset = None
stop_scrolling = False
while not stop_scrolling:
records, next_offset = self._client.scroll(
collection_name, # 使用指定集合名称
limit=SCROLL_SIZE, # 每次滚动获取的记录数量
offset=next_offset, # 当前滚动偏移量
with_payload=True,
with_vectors=False,
)
stop_scrolling = next_offset is None or (
isinstance(next_offset, grpc.PointId)
and next_offset.num == 0
and next_offset.uuid == ""
)
results.extend(records)
return results # 返回所有记录的列表
- 获取指定集合中的所有点记录。
设置集合的方法
def _setup_collections(self):
if not self._client.collection_exists(self.sql_collection_name):
self._client.create_collection(
collection_name=self.sql_collection_name, # 创建SQL集合
vectors_config=models.VectorParams(
size=self.embeddings_dimension, # 嵌入向量的维度
distance=self.distance_metric, # 使用的距离度量
),
**self.collection_params,
)
if not self._client.collection_exists(self.ddl_collection_name):
self._client.create_collection(
collection_name=self.ddl_collection_name, # 创建DDL集合
vectors_config=models.VectorParams(
size=self.embeddings_dimension, # 嵌入向量的维度
distance=self.distance_metric, # 使用的距离度量
),
**self.collection_params,
)
if not self._client.collection_exists(self.documentation_collection_name):
self._client.create_collection(
collection_name=self.documentation_collection_name, # 创建文档集合
vectors_config=models.VectorParams(
size=self.embeddings_dimension, # 嵌入向量的维度
distance=self.distance_metric, # 使用的距离度量
),
**self.collection_params,
)
- 设置SQL、DDL和文档集合,如果集合不存在则创建集合。
格式化点ID的方法
def _format_point_id(self, id: str, collection_name: str) -> str:
return "{0}-{1}".format(id, self.id_suffixes[collection_name]) # 返回格式化的点ID
- 格式化点ID,添加集合后缀。
解析点ID的方法
def _parse_point_id(self, id: str) -> Tuple[str, str]:
id, curr_suffix = id.rsplit("-", 1)
for collection_name, suffix in self.id_suffixes.items():
if curr_suffix == suffix:
return id, collection_name
raise ValueError(f"Invalid id {id}") # 如果ID无效则抛出异常
- 解析点ID,返回ID和集合名称。
types
该代码定义了一些数据类和一个训练计划类,用于表示各种数据结构和训练计划。
-
数据类:
- 使用
@dataclass
装饰器定义了一系列数据类,如Status
、QuestionList
、FullQuestionDocument
等。 - 这些类用于描述和存储不同类型的数据,如问题、答案、组织信息、数据结果等。
- 使用
-
训练计划类:
TrainingPlanItem
类表示训练计划中的一个项目。TrainingPlan
类表示一个训练计划,可以获取计划的概要并移除不需要的项目。
数据类和训练计划类解析
下面是提供的代码的详细解析,包括逐行注释:
from __future__ import annotations # 允许在当前文件中使用未来版本的特性
from dataclasses import dataclass # 从dataclasses模块导入dataclass装饰器
from typing import Dict, List, Union # 导入类型注解
@dataclass
class Status:
success: bool # 操作是否成功
message: str # 返回的信息
@dataclass
class StatusWithId:
success: bool # 操作是否成功
message: str # 返回的信息
id: str # 关联的ID
@dataclass
class QuestionList:
questions: List[FullQuestionDocument] # 包含问题文档的列表
@dataclass
class FullQuestionDocument:
id: QuestionId # 问题ID
question: Question # 问题内容
answer: SQLAnswer | None # SQL答案(如果有)
data: DataResult | None # 数据结果(如果有)
plotly: PlotlyResult | None # Plotly结果(如果有)
@dataclass
class QuestionSQLPair:
question: str # 问题内容
sql: str # 对应的SQL查询
tag: Union[str, None] # 标签(可选)
@dataclass
class Organization:
name: str # 组织名称
user: str | None # 用户名称(可选)
connection: Connection | None # 连接(可选)
@dataclass
class OrganizationList:
organizations: List[str] # 组织名称列表
@dataclass
class QuestionStringList:
questions: List[str] # 问题内容列表
@dataclass
class Visibility:
visibility: bool # 可见性状态
@dataclass
class UserEmail:
email: str # 用户电子邮件
@dataclass
class NewOrganization:
org_name: str # 新组织名称
db_type: str # 数据库类型
@dataclass
class NewOrganizationMember:
org_name: str # 组织名称
email: str # 成员电子邮件
is_admin: bool # 是否为管理员
@dataclass
class UserOTP:
email: str # 用户电子邮件
otp: str # 一次性密码
@dataclass
class ApiKey:
key: str # API密钥
@dataclass
class QuestionId:
id: str # 问题ID
@dataclass
class Question:
question: str # 问题内容
@dataclass
class QuestionCategory:
question: str # 问题内容
category: str # 问题类别
# 预定义的类别常量
NO_SQL_GENERATED = "No SQL Generated"
SQL_UNABLE_TO_RUN = "SQL Unable to Run"
BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query"
SQL_RAN = "SQL Ran Successfully"
FLAGGED_FOR_REVIEW = "Flagged for Review"
REVIEWED_AND_APPROVED = "Reviewed and Approved"
REVIEWED_AND_REJECTED = "Reviewed and Rejected"
REVIEWED_AND_UPDATED = "Reviewed and Updated"
@dataclass
class AccuracyStats:
num_questions: int # 问题数量
data: Dict[str, int] # 各类别的问题数量
@dataclass
class Followup:
followup: str # 后续问题
@dataclass
class QuestionEmbedding:
question: Question # 问题内容
embedding: List[float] # 嵌入向量
@dataclass
class Connection:
# TODO: 实现连接类
pass
@dataclass
class SQLAnswer:
raw_answer: str # 原始答案
prefix: str # 前缀
postfix: str # 后缀
sql: str # SQL查询
@dataclass
class Explanation:
explanation: str # 解释内容
@dataclass
class DataResult:
question: str | None # 问题内容(可选)
sql: str | None # SQL查询(可选)
table_markdown: str # 表格的Markdown表示
error: str | None # 错误信息(可选)
correction_attempts: int # 修正尝试次数
@dataclass
class PlotlyResult:
plotly_code: str # Plotly代码
@dataclass
class WarehouseDefinition:
name: str # 仓库名称
tables: List[TableDefinition] # 表定义列表
@dataclass
class TableDefinition:
schema_name: str # 模式名称
table_name: str # 表名称
ddl: str | None # DDL语句(可选)
columns: List[ColumnDefinition] # 列定义列表
@dataclass
class ColumnDefinition:
name: str # 列名称
type: str # 列类型
is_primary_key: bool # 是否为主键
is_foreign_key: bool # 是否为外键
foreign_key_table: str # 外键表
foreign_key_column: str # 外键列
@dataclass
class Diagram:
raw: str # 原始内容
mermaid_code: str # Mermaid代码
@dataclass
class StringData:
data: str # 字符串数据
@dataclass
class DataFrameJSON:
data: str # JSON格式的数据
@dataclass
class TrainingData:
questions: List[dict] # 问题列表
ddl: List[str] # DDL语句列表
documentation: List[str] # 文档列表
@dataclass
class TrainingPlanItem:
item_type: str # 项目类型
item_group: str # 项目组
item_name: str # 项目名称
item_value: str # 项目值
def __str__(self):
if self.item_type == self.ITEM_TYPE_SQL:
return f"Train on SQL: {self.item_group} {self.item_name}"
elif self.item_type == self.ITEM_TYPE_DDL:
return f"Train on DDL: {self.item_group} {self.item_name}"
elif self.item_type == self.ITEM_TYPE_IS:
return f"Train on Information Schema: {self.item_group} {self.item_name}"
ITEM_TYPE_SQL = "sql"
ITEM_TYPE_DDL = "ddl"
ITEM_TYPE_IS = "is"
class TrainingPlan:
"""
A class representing a training plan. You can see what's in it, and remove items from it that you don't want trained.
**Example:**
```python
plan = vn.get_training_plan()
plan.get_summary()
```
"""
_plan: List[TrainingPlanItem]
def __init__(self, plan: List[TrainingPlanItem]):
self._plan = plan # 初始化训练计划
def __str__(self):
return "\n".join(self.get_summary()) # 返回训练计划的字符串表示
def __repr__(self):
return self.__str__() # 返回训练计划的字符串表示
def get_summary(self) -> List[str]:
"""
**Example:**
```python
plan = vn.get_training_plan()
plan.get_summary()
```
Get a summary of the training plan.
Returns:
List[str]: A list of strings describing the training plan.
"""
return [f"{item}" for item in self._plan] # 返回训练计划的概要
def remove_item(self, item: str):
"""
**Example:**
```python
plan = vn.get_training_plan()
plan.remove_item("Train on SQL: What is the average salary of employees?")
```
Remove an item from the training plan.
Args:
item (str): The item to remove.
"""
for plan_item in self._plan:
if str(plan_item) == item:
self._plan.remove(plan_item) # 从训练计划中移除指定项
break
vannadb_vector
VannaDB_VectorStore类
该类实现了一个基于Vanna API的向量存储,用于存储和检索向量化的数据,需要与Vanna官网通信。
这个类通过HTTP请求与Vanna的API进行交互,包括以下操作:
- RPC调用:通过POST请求与Vanna的RPC端点通信,以调用各种方法(如add_sql、get_training_data等)。
- GraphQL调用:通过POST请求与Vanna的GraphQL端点通信,以执行查询和变更(如获取函数、创建函数、更新函数等)。
初始化与配置
def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
VannaBase.__init__(self, config=config) # 调用父类的初始化方法
self._model = vanna_model # 设置模型名称
self._api_key = vanna_api_key # 设置API密钥
self._endpoint = (
"https://ask.vanna.ai/rpc"
if config is None or "endpoint" not in config
else config["endpoint"]
)
self.related_training_data = {} # 初始化相关训练数据字典
self._graphql_endpoint = "https://functionrag.com/query" # 设置GraphQL端点
self._graphql_headers = {
"Content-Type": "application/json",
"API-KEY": self._api_key,
"NAMESPACE": self._model,
}
- 初始化
VannaDB_VectorStore
类实例。 - 根据配置参数初始化API端点和相关头信息。
- 设置模型名称和API密钥。
RPC调用方法
def _rpc_call(self, method, params):
if method != "list_orgs": # 设置请求头
headers = {
"Content-Type": "application/json",
"Vanna-Key": self._api_key,
"Vanna-Org": self._model,
}
else:
headers = {
"Content-Type": "application/json",
"Vanna-Key": self._api_key,
"Vanna-Org": "demo-tpc-h",
}
data = {
"method": method,
"params": [self._dataclass_to_dict(obj) for obj in params], # 将参数转换为字典
}
response = requests.post(self._endpoint, headers=headers, data=json.dumps(data)) # 发送POST请求
return response.json() # 返回JSON响应
- 发送RPC请求到Vanna API。
- 设置请求头和请求数据。
- 返回API响应的JSON数据。
数据类转换方法
def _dataclass_to_dict(self, obj):
return dataclasses.asdict(obj) # 将数据类对象转换为字典
- 将数据类对象转换为字典格式。
获取所有函数方法
def get_all_functions(self) -> list:
query = """
{
get_all_sql_functions {
function_name
description
post_processing_code_template
arguments {
name
description
general_type
is_user_editable
available_values
}
sql_template
}
}
"""
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query}) # 发送GraphQL请求
response_json = response.json() # 获取JSON响应
if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
self.log(response_json['data']['get_all_sql_functions']) # 记录日志
resp = response_json['data']['get_all_sql_functions']
print(resp)
return resp # 返回函数列表
else:
raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
- 发送GraphQL查询请求以获取所有SQL函数。
- 返回函数列表。
获取函数方法
def get_function(self, question: str, additional_data: dict = {}) -> dict:
query = """
query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
... on SQLFunction {
function_name
description
post_processing_code_template
instantiated_post_processing_code
arguments {
name
description
general_type
is_user_editable
instantiated_value
available_values
}
sql_template
instantiated_sql
}
}
}
"""
static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
variables = {"question": question, "staticFunctionArguments": static_function_arguments}
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables}) # 发送GraphQL请求
response_json = response.json() # 获取JSON响应
if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
self.log(response_json['data']['get_and_instantiate_function']) # 记录日志
resp = response_json['data']['get_and_instantiate_function']
print(resp)
return resp # 返回函数信息
else:
raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
- 发送GraphQL查询请求以获取特定问题的函数。
- 返回函数信息。
创建函数方法
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
query = """
mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
function_name
description
arguments {
name
description
general_type
is_user_editable
}
sql_template
post_processing_code_template
}
}
"""
variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables}) # 发送GraphQL请求
response_json = response.json() # 获取JSON响应
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
resp = response_json['data']['generate_and_create_sql_function']
print(resp)
return resp # 返回新创建的函数信息
else:
raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")
- 发送GraphQL变更请求以创建新的SQL函数。
- 返回新创建的函数信息。
更新函数方法
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
mutation = """
mutation UpdateSQLFunction($input: SQLFunctionUpdate!) {
update_sql_function(input: $input)
}
"""
SQLFunctionUpdate = {
'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
}
ArgumentKeys = {'name', 'general_type', 'description', 'is_user_editable', 'available_values'}
def validate_arguments(args):
return [
{key: arg[key] for key in arg if key in ArgumentKeys}
for arg in args
]
updated_function = {key: value for key, value in updated_function.items() if key in SQLFunctionUpdate}
if 'arguments' in updated_function:
updated_function['arguments'] = validate_arguments(updated_function['arguments'])
variables = {
"input": {
"old_function_name": old_function_name,
**updated_function
}
}
print("variables", variables)
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables}) # 发送GraphQL请求
response_json = response.json() # 获取JSON响应
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
return response_json['data']['update_sql_function'] # 返回更新结果
else:
raise Exception(f"Mutation failed to run by returning code of {response.status_code}. {response.text}")
- 发送GraphQL变更请求以更新现有的SQL函数。
- 返回更新结果。
删除函数方法
def delete_function(self, function_name: str) -> bool:
mutation = """
mutation DeleteSQLFunction($function_name: String!) {
delete_sql_function(function_name: $function_name)
}
"""
variables = {"function_name": function_name} # 设置变量
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables}) # 发送POST请求
response_json = response.json() # 获取JSON响应
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
return response_json['data']['delete_sql_function'] # 返回删除结果
else:
raise Exception
创建模型方法
def create_model(self, model: str, **kwargs) -> bool:
model = sanitize_model_name(model) # 清理模型名称
params = [NewOrganization(org_name=model, db_type="")]
d = self._rpc_call(method="create_org", params=params) # 调用RPC创建组织
if "result" not in d:
return False
status = Status(**d["result"])
return status.success # 返回创建状态
- 调用RPC创建新的模型组织。
- 返回创建状态。
获取模型方法
def get_models(self) -> list:
d = self._rpc_call(method="list_my_models", params=[]) # 调用RPC列出模型
if "result" not in d:
return []
orgs = OrganizationList(**d["result"])
return orgs.organizations # 返回模型列表
- 调用RPC获取用户的所有模型。
- 返回模型列表。
生成嵌入方法
def generate_embedding(self, data: str, **kwargs) -> list[float]:
pass # 在服务器端生成嵌入
- 生成嵌入向量的方法占位符,实际实现由服务器端完成。
添加问题和SQL方法
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
tag = kwargs.get("tag", "Manually Trained") # 获取标签,默认为“Manually Trained”
params = [QuestionSQLPair(question=question, sql=sql, tag=tag)]
d = self._rpc_call(method="add_sql", params=params) # 调用RPC添加问题和SQL对
if "result" not in d:
raise Exception("Error adding question and SQL pair", d)
status = StatusWithId(**d["result"])
return status.id # 返回新添加的ID
- 调用RPC添加问题和SQL对。
- 返回新添加的ID。
添加DDL方法
def add_ddl(self, ddl: str, **kwargs) -> str:
params = [StringData(data=ddl)]
d = self._rpc_call(method="add_ddl", params=params) # 调用RPC添加DDL
if "result" not in d:
raise Exception("Error adding DDL", d)
status = StatusWithId(**d["result"])
return status.id # 返回新添加的ID
- 调用RPC添加DDL语句。
- 返回新添加的ID。
添加文档方法
def add_documentation(self, documentation: str, **kwargs) -> str:
params = [StringData(data=documentation)]
d = self._rpc_call(method="add_documentation", params=params) # 调用RPC添加文档
if "result" not in d:
raise Exception("Error adding documentation", d)
status = StatusWithId(**d["result"])
return status.id # 返回新添加的ID
- 调用RPC添加文档。
- 返回新添加的ID。
获取训练数据方法
def get_training_data(self, **kwargs) -> pd.DataFrame:
params = []
d = self._rpc_call(method="get_training_data", params=params) # 调用RPC获取训练数据
if "result" not in d:
return None
training_data = DataFrameJSON(**d["result"])
df = pd.read_json(StringIO(training_data.data)) # 将JSON数据读取为DataFrame
return df # 返回DataFrame
- 调用RPC获取训练数据。
- 返回训练数据的DataFrame。
移除训练数据方法
def remove_training_data(self, id: str, **kwargs) -> bool:
params = [StringData(data=id)]
d = self._rpc_call(method="remove_training_data", params=params) # 调用RPC移除训练数据
if "result" not in d:
raise Exception("Error removing training data")
status = Status(**d["result"])
if not status.success:
raise Exception(f"Error removing training data: {status.message}")
return status.success # 返回移除状态
- 调用RPC移除指定ID的训练数据。
- 返回移除状态。
获取缓存的相关训练数据方法
def get_related_training_data_cached(self, question: str) -> TrainingData:
params = [Question(question=question)]
d = self._rpc_call(method="get_related_training_data", params=params) # 调用RPC获取相关训练数据
if "result" not in d:
return None
training_data = TrainingData(**d["result"])
self.related_training_data[question] = training_data # 缓存相关训练数据
return training_data # 返回训练数据
- 调用RPC获取相关训练数据,并将其缓存。
- 返回训练数据。
获取相似问题和SQL方法
def get_similar_question_sql(self, question: str, **kwargs) -> list:
if question in self.related_training_data:
training_data = self.related_training_data[question]
else:
training_data = self.get_related_training_data_cached(question) # 从缓存获取相关训练数据
return training_data.questions # 返回相似问题和SQL对
- 获取与指定问题相似的问题和SQL对。
- 如果缓存中存在相关数据则直接使用,否则通过RPC获取并缓存。
获取相关DDL方法
def get_related_ddl(self, question: str, **kwargs) -> list:
if question in self.related_training_data:
training_data = self.related_training_data[question]
else:
training_data = self.get_related_training_data_cached(question) # 从缓存获取相关训练数据
return training_data.ddl # 返回相关DDL
- 获取与指定问题相关的DDL。
- 如果缓存中存在相关数据则直接使用,否则通过RPC获取并缓存。
获取相关文档方法
def get_related_documentation(self, question: str, **kwargs) -> list:
if question in self.related_training_data:
training_data = self.related_training_data[question]
else:
training_data = self.get_related_training_data_cached(question) # 从缓存获取相关训练数据
return training_data.documentation # 返回相关文档
- 获取与指定问题相关的文档。
- 如果缓存中存在相关数据则直接使用,否则通过RPC获取并缓存。
vllm
这个类提供了与VLLM(假设为一个虚拟语言模型)进行交互的功能。通过配置参数初始化类实例,并提供多种方法以便与VLLM进行通信和处理。
- 该类通过配置参数初始化,并提供与VLLM服务交互的方法。
- 可以生成系统消息、用户消息和助手消息。
- 提供提取SQL查询和生成SQL的方法。
- 通过提交提示与VLLM服务通信,获取生成的消息内容。
初始化方法
def __init__(self, config=None):
if config is None or "vllm_host" not in config:
self.host = "http://localhost:8000" # 设置默认主机地址
else:
self.host = config["vllm_host"] # 使用配置中的主机地址
if config is None or "model" not in config:
raise ValueError("check the config for vllm") # 检查配置中的模型参数
else:
self.model = config["model"] # 设置模型名称
if "auth-key" in config:
self.auth_key = config["auth-key"] # 设置认证密钥
else:
self.auth_key = None # 如果没有认证密钥,则设为None
- 根据配置参数初始化类实例。
- 检查并设置主机地址、模型名称和认证密钥。
消息生成方法
def system_message(self, message: str) -> any:
return {"role": "system", "content": message} # 生成系统消息
def user_message(self, message: str) -> any:
return {"role": "user", "content": message} # 生成用户消息
def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message} # 生成助手消息
- 生成系统消息、用户消息和助手消息,分别返回不同角色的消息字典。
SQL查询提取方法
def extract_sql_query(self, text):
"""
提取第一个SQL语句,该语句在'`select`'之后,忽略大小写,
匹配直到第一个分号、三个反引号或字符串末尾,
并在提取的字符串中去除三个反引号。
Args:
- text (str): 要搜索的字符串。
Returns:
- str: 提取的第一个SQL语句,去除三个反引号,如果没有匹配则返回空字符串。
"""
# 正则表达式查找'`select`'(忽略大小写),并捕获直到';'、'```'或字符串末尾
pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
match = pattern.search(text)
if match:
# 如果存在匹配,则去除匹配字符串中的三个反引号
return match.group(0).replace("```", "")
else:
return text
- 提取第一个SQL语句,该语句在
select
之后,忽略大小写,匹配直到第一个分号、三个反引号或字符串末尾,并在提取的字符串中去除三个反引号。
生成SQL方法
def generate_sql(self, question: str, **kwargs) -> str:
sql = super().generate_sql(question, **kwargs) # 调用父类的generate_sql方法
sql = sql.replace("\\_", "_") # 将"\_"替换为"_"
sql = sql.replace("\\", "") # 去除反斜杠
return self.extract_sql_query(sql) # 提取并返回SQL查询
- 生成SQL查询,处理特殊字符并提取SQL语句。
提交提示方法
def submit_prompt(self, prompt, **kwargs) -> str:
url = f"{self.host}/v1/chat/completions"
data = {
"model": self.model,
"stream": False,
"messages": prompt,
}
if self.auth_key is not None:
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.auth_key}'
}
response = requests.post(url, headers=headers, json=data) # 发送带有认证头的POST请求
else:
response = requests.post(url, json=data) # 发送不带认证头的POST请求
response_dict = response.json() # 获取JSON响应
self.log(response.text) # 记录响应日志
return response_dict['choices'][0]['message']['content'] # 返回生成的消息内容
- 提交提示到VLLM服务,处理响应并返回生成的消息内容。