前言
vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析
一、vanna源码分析
chromadb_vector.py:
以下代码是用于管理和操作与 ChromaDB 向量存储相关的数据。该文件定义了ChromaDB_VectorStore 类,该类继承自 VannaBase,用于处理以下任务:
1.初始化 ChromaDB 客户端:根据配置创建持久性或内存中的 ChromaDB 客户端。
2.创建或获取集合:包括 documentation、ddl 和 sql 三个集合,用于存储相应的数据。
3.生成嵌入向量:使用嵌入函数生成数据的向量表示。
4.添加数据:将问题与 SQL、DDL 或文档添加到相应的集合中。
5.获取训练数据:从集合中提取数据并返回 pandas DataFrame。
6.删除数据:根据 ID 从集合中删除特定数据。
7.删除集合:重置集合为空状态。
8.查询相似数据:从集合中查询与问题相关的 SQL、DDL 或文档。
import json # 导入 json 模块,用于处理 JSON 数据
from typing import List # 从 typing 模块导入 List 类型,用于类型注解
import chromadb # 导入 chromadb 模块,用于操作 Chroma 数据库
import pandas as pd # 导入 pandas 模块,用于数据处理
from chromadb.config import Settings # 从 chromadb.config 模块导入 Settings 类
from chromadb.utils import embedding_functions # 从 chromadb.utils 模块导入 embedding_functions
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
from ..utils import deterministic_uuid # 从上一级目录的 utils 模块导入 deterministic_uuid 函数
default_ef = embedding_functions.DefaultEmbeddingFunction() # 获取默认的嵌入函数实例
class ChromaDB_VectorStore(VannaBase): # 定义 ChromaDB_VectorStore 类,继承自 VannaBase
def __init__(self, config=None): # 初始化函数,接受可选的配置参数
VannaBase.__init__(self, config=config) # 调用父类的初始化方法
if config is None: # 如果未提供配置,使用空字典作为默认配置
config = {}
path = config.get("path", ".") # 获取配置中的路径,默认为当前目录
self.embedding_function = config.get("embedding_function", default_ef) # 获取嵌入函数,默认为 default_ef
curr_client = config.get("client", "persistent") # 获取客户端类型,默认为持久化客户端
collection_metadata = config.get("collection_metadata", None) # 获取集合元数据
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10)) # 获取 SQL 查询结果数量,默认为 10
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10)) # 获取文档查询结果数量,默认为 10
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10)) # 获取 DDL 查询结果数量,默认为 10
if curr_client == "persistent": # 如果客户端类型是持久化的
self.chroma_client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
elif curr_client == "in-memory": # 如果客户端类型是内存中的
self.chroma_client = chromadb.EphemeralClient(
settings=Settings(anonymized_telemetry=False)
)
elif isinstance(curr_client, chromadb.api.client.Client): # 如果直接提供了客户端实例
self.chroma_client = curr_client # 使用提供的客户端实例
else:
raise ValueError(f"Unsupported client was set in config: {curr_client}") # 如果客户端类型不支持,抛出异常
# 获取或创建 documentation 集合
self.documentation_collection = self.chroma_client.get_or_create_collection(
name="documentation",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
# 获取或创建 ddl 集合
self.ddl_collection = self.chroma_client.get_or_create_collection(
name="ddl",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
# 获取或创建 sql 集合
self.sql_collection = self.chroma_client.get_or_create_collection(
name="sql",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
def generate_embedding(self, data: str, **kwargs) -> List[float]: # 生成嵌入
embedding = self.embedding_function([data]) # 调用嵌入函数生成嵌入
if len(embedding) == 1: # 如果生成的嵌入只有一个
return embedding[0] # 返回该嵌入
return embedding # 否则返回整个嵌入列表
def add_question_sql(self, question: str, sql: str, **kwargs) -> str: # 添加 SQL 问题
question_sql_json = json.dumps(
{
"question": question,
"sql": sql,
},
ensure_ascii=False,
) # 将问题和 SQL 转换为 JSON 字符串
id = deterministic_uuid(question_sql_json) + "-sql" # 生成唯一 ID,并添加 "-sql" 后缀
self.sql_collection.add(
documents=question_sql_json, # 添加文档
embeddings=self.generate_embedding(question_sql_json), # 生成并添加嵌入
ids=id, # 添加 ID
)
return id # 返回 ID
def add_ddl(self, ddl: str, **kwargs) -> str: # 添加 DDL
id = deterministic_uuid(ddl) + "-ddl" # 生成唯一 ID,并添加 "-ddl" 后缀
self.ddl_collection.add(
documents=ddl, # 添加文档
embeddings=self.generate_embedding(ddl), # 生成并添加嵌入
ids=id, # 添加 ID
)
return id # 返回 ID
def add_documentation(self, documentation: str, **kwargs) -> str: # 添加文档
id = deterministic_uuid(documentation) + "-doc" # 生成唯一 ID,并添加 "-doc" 后缀
self.documentation_collection.add(
documents=documentation, # 添加文档
embeddings=self.generate_embedding(documentation), # 生成并添加嵌入
ids=id, # 添加 ID
)
return id # 返回 ID
def get_training_data(self, **kwargs) -> pd.DataFrame: # 获取训练数据
sql_data = self.sql_collection.get() # 获取 SQL 集合中的数据
df = pd.DataFrame() # 创建一个空的 DataFrame
if sql_data is not None: # 如果 SQL 数据不为空
documents = [json.loads(doc) for doc in sql_data["documents"]] # 解析 JSON 文档
ids = sql_data["ids"] # 获取文档 ID
df_sql = pd.DataFrame(
{
"id": ids,
"question": [doc["question"] for doc in documents],
"content": [doc["sql"] for doc in documents],
}
) # 创建 SQL 数据的 DataFrame
df_sql["training_data_type"] = "sql" # 添加数据类型列
df = pd.concat([df, df_sql]) # 将 SQL 数据合并到主 DataFrame
ddl_data = self.ddl_collection.get() # 获取 DDL 集合中的数据
if ddl_data is not None: # 如果 DDL 数据不为空
documents = [doc for doc in ddl_data["documents"]] # 获取文档
ids = ddl_data["ids"] # 获取文档 ID
df_ddl = pd.DataFrame(
{
"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents],
}
) # 创建 DDL 数据的 DataFrame
df_ddl["training_data_type"] = "ddl" # 添加数据类型列
df = pd.concat([df, df_ddl]) # 将 DDL 数据合并到主 DataFrame
doc_data = self.documentation_collection.get() # 获取文档集合中的数据
if doc_data is not None: # 如果文档数据不为空
documents = [doc for doc in doc_data["documents"]] # 获取文档
ids = doc_data["ids"] # 获取文档 ID
df_doc = pd.DataFrame(
{
"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents],
}
) # 创建文档数据的 DataFrame
df_doc["training_data_type"] = "documentation" # 添加数据类型列
df = pd.concat([df, df_doc]) # 将文档数据合并到主 DataFrame
return df # 返回最终的 DataFrame
def remove_training_data(self, id: str, **kwargs) -> bool: # 删除训练数据
if id.endswith("-sql"): # 如果 ID 以 "-sql" 结尾
self.sql_collection.delete(ids=id) # 删除 SQL 集合中的该 ID
return True
elif id.endswith("-ddl"): # 如果 ID 以 "-ddl" 结尾
self.ddl_collection.delete(ids=id) # 删除 DDL 集合中的该 ID
return True
elif id.endswith("-doc"): # 如果 ID 以 "-doc" 结尾
self.documentation_collection.delete(ids=id) # 删除文档集合中的该 ID
return True
else:
return False # 如果 ID 不符合上述任何条件,返回 False
def remove_collection(self, collection_name: str) -> bool: # 删除集合
"""
This function can reset the collection to empty state.
Args:
collection_name (str): sql or ddl or documentation
Returns:
bool: True if collection is deleted, False otherwise
"""
if collection_name == "sql": # 如果集合名称是 "sql"
self.chroma_client.delete_collection(name="sql") # 删除 sql 集合
self.sql_collection = self.chroma_client.get_or_create_collection(
name="sql", embedding_function=self.embedding_function
) # 重新创建 sql 集合
return True
elif collection_name == "ddl": # 如果集合名称是 "ddl"
self.chroma_client.delete_collection(name="ddl") # 删除 ddl 集合
self.ddl_collection = self.chroma_client.get_or_create_collection(
name="ddl", embedding_function=self.embedding_function
) # 重新创建 ddl 集合
return True
elif collection_name == "documentation": # 如果集合名称是 "documentation"
self.chroma_client.delete_collection(name="documentation") # 删除 documentation 集合
self.documentation_collection = self.chroma_client.get_or_create_collection(
name="documentation", embedding_function=self.embedding_function
) # 重新创建 documentation 集合
return True
else:
return False # 如果集合名称不符合上述任何条件,返回 False
@staticmethod
def _extract_documents(query_results) -> list: # 静态方法:从查询结果中提取文档
"""
Static method to extract the documents from the results of a query.
Args:
query_results (pd.DataFrame): The dataframe to use.
Returns:
List[str] or None: The extracted documents, or an empty list or
single document if an error occurred.
"""
if query_results is None: # 如果查询结果为空
return [] # 返回空列表
if "documents" in query_results: # 如果查询结果包含 "documents" 字段
documents = query_results["documents"]
if len(documents) == 1 and isinstance(documents[0], list): # 如果只有一个文档并且是列表类型
try:
documents = [json.loads(doc) for doc in documents[0]] # 尝试解析 JSON 文档
except Exception as e: # 如果解析失败
return documents[0] # 返回原始文档
return documents # 返回文档列表
def get_similar_question_sql(self, question: str, **kwargs) -> list: # 获取类似问题的 SQL
return ChromaDB_VectorStore._extract_documents(
self.sql_collection.query(
query_texts=[question], # 查询文本
n_results=self.n_results_sql, # 查询结果数量
)
)
def get_related_ddl(self, question: str, **kwargs) -> list: # 获取相关的 DDL
return ChromaDB_VectorStore._extract_documents(
self.ddl_collection.query(
query_texts=[question], # 查询文本
n_results=self.n_results_ddl, # 查询结果数量
)
)
def get_related_documentation(self, question: str, **kwargs) -> list: # 获取相关的文档
return ChromaDB_VectorStore._extract_documents(
self.documentation_collection.query(
query_texts=[question], # 查询文本
n_results=self.n_results_documentation, # 查询结果数量
)
)