文章目录
前言
vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析,该部分内容为接入各类ai方法。
一、vanna源码分析
gemini_chat.py
这段代码定义了一个 GoogleGeminiChat
类,继承自 VannaBase
。
- 这个类主要用于与 Google 的生成式 AI 模型进行交互,特别是用于聊天应用。
- 通过配置 API 密钥和模型名称,可以灵活地使用不同的生成模型。
- 提供了简单的消息处理方法,可以根据需要进行扩展。
import os
from ..base import VannaBase # 从父模块导入 VannaBase 类
class GoogleGeminiChat(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config) # 调用父类的构造函数
# 默认的温度值,可以通过 config 覆盖
self.temperature = 0.7
# 如果配置中包含 temperature,则覆盖默认值
if "temperature" in config:
self.temperature = config["temperature"]
# 设置模型名称,如果配置中没有指定,则使用默认值 "gemini-1.0-pro"
if "model_name" in config:
model_name = config["model_name"]
else:
model_name = "gemini-1.0-pro"
self.google_api_key = None # 初始化 API 密钥变量
# 如果配置中提供了 api_key 或环境变量中有 GOOGLE_API_KEY,则使用它
if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
"""
如果 Google api_key 通过配置提供
或设置为环境变量,则分配它。
"""
import google.generativeai as genai # 导入 google.generativeai 库
genai.configure(api_key=config["api_key"]) # 使用提供的 API 密钥进行配置
self.chat_model = genai.GenerativeModel(model_name) # 初始化生成模型
else:
# 使用 VertexAI 进行身份验证
from vertexai.preview.generative_models import GenerativeModel # 导入 VertexAI 的生成模型类
self.chat_model = GenerativeModel("gemini-pro") # 初始化生成模型
def system_message(self, message: str) -> any:
return message # 返回系统消息
def user_message(self, message: str) -> any:
return message # 返回用户消息
def assistant_message(self, message: str) -> any:
return message # 返回助手消息
def submit_prompt(self, prompt, **kwargs) -> str:
# 使用生成模型生成内容
response = self.chat_model.generate_content(
prompt,
generation_config={
"temperature": self.temperature, # 使用配置的温度值
},
)
return response.text # 返回生成的文本
功能和作用
-
类的初始化:
- 初始化时调用父类
VannaBase
的构造函数。 - 设置默认的温度值
self.temperature
,可以通过配置覆盖。 - 设置模型名称
model_name
,可以通过配置覆盖,默认使用"gemini-1.0-pro"
。 - 检查是否提供了 API 密钥,如果提供了,则使用
google.generativeai
库进行配置并初始化生成模型。如果没有提供 API 密钥,则使用 VertexAI 进行身份验证并初始化生成模型。
- 初始化时调用父类
-
消息处理方法:
system_message
、user_message
和assistant_message
方法都是简单地返回传入的消息。这些方法可以在实际应用中进行扩展,以处理不同类型的消息。
-
提交提示:
submit_prompt
方法使用生成模型生成内容。它接受一个提示(prompt)并生成相应的文本。生成的配置包括温度值。
hf
hf.py
这段代码定义了一个 Hf
类,继承自 VannaBase
。
- 这个类主要用于与 Hugging Face 的生成模型进行交互,特别是用于生成 SQL 查询。
- 通过配置模型名称,可以灵活地使用不同的生成模型。
- 提供了简单的消息处理方法和 SQL 提取方法,可以根据需要进行扩展。
import re # 导入正则表达式模块
from transformers import AutoTokenizer, AutoModelForCausalLM # 从 transformers 库中导入自动分词器和因果语言模型
from ..base import VannaBase # 从父模块导入 VannaBase 类
class Hf(VannaBase):
def __init__(self, config=None):
# 从配置中获取模型名称,例如 "meta-llama/Meta-Llama-3-8B-Instruct"
model_name = self.config.get("model_name", None)
# 从预训练模型中加载分词器
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# 从预训练模型中加载因果语言模型,设置数据类型和设备映射为自动
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
)
def system_message(self, message: str) -> any:
# 返回包含角色和内容的字典,表示系统消息
return {"role": "system", "content": message}
def user_message(self, message: str) -> any:
# 返回包含角色和内容的字典,表示用户消息
return {"role": "user", "content": message}
def assistant_message(self, message: str) -> any:
# 返回包含角色和内容的字典,表示助手消息
return {"role": "assistant", "content": message}
def extract_sql_query(self, text):
"""
提取第一个 SQL 语句,从 'select' 开始,不区分大小写,
匹配到第一个分号、三个反引号或字符串结尾,
并移除提取字符串中的三个反引号(如果存在)。
参数:
- text (str): 要搜索 SQL 语句的字符串。
返回:
- str: 找到的第一个 SQL 语句,移除三个反引号后的结果,如果没有匹配则返回空字符串。
"""
# 正则表达式模式,用于查找 'select'(忽略大小写)并捕获到分号、三个反引号或字符串结尾
pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
# 搜索匹配项
match = pattern.search(text)
if match:
# 如果存在匹配项,移除匹配字符串中的三个反引号
return match.group(0).replace("```", "")
else:
# 如果没有匹配项,返回输入字符串
return text
def generate_sql(self, question: str, **kwargs) -> str:
# 使用父类的 generate_sql 方法生成 SQL 语句
sql = super().generate_sql(question, **kwargs)
# 替换字符串中的 "\_" 为 "_"
sql = sql.replace("\\_", "_")
# 替换字符串中的 "\" 为 ""
sql = sql.replace("\\", "")
# 提取并返回 SQL 查询
return self.extract_sql_query(sql)
def submit_prompt(self, prompt, **kwargs) -> str:
# 使用分词器对提示进行处理,添加生成提示并返回张量格式
input_ids = self.tokenizer.apply_chat_template(
prompt, add_generation_prompt=True, return_tensors="pt"
).to(self.model.device)
# 使用模型生成输出,设置最大新标记数、结束标记 ID、是否进行采样、温度和 top-p 参数
outputs = self.model.generate(
input_ids,
max_new_tokens=512,
eos_token_id=self.tokenizer.eos_token_id,
do_sample=True,
temperature=1,
top_p=0.9,
)
# 提取生成的响应,跳过输入部分
response = outputs[0][input_ids.shape[-1] :]
# 解码生成的响应,跳过特殊标记
response = self.tokenizer.decode(response, skip_special_tokens=True)
# 记录响应
self.log(response)
# 返回响应
return response
功能和作用
-
类的初始化:
- 初始化时从配置中获取模型名称,并加载相应的分词器和因果语言模型。
- 设置模型的
torch_dtype
和device_map
为auto
,以便自动调整数据类型和设备。
-
消息处理方法:
system_message
、user_message
和assistant_message
方法返回包含角色和内容的字典,用于表示不同类型的消息。
-
提取 SQL 查询:
extract_sql_query
方法使用正则表达式从输入文本中提取第一个 SQL 语句,匹配到分号、三个反引号或字符串结尾,并移除提取字符串中的三个反引号。
-
生成 SQL 查询:
generate_sql
方法首先调用父类的generate_sql
方法生成 SQL 语句,然后替换字符串中的特定字符,并使用extract_sql_query
方法提取 SQL 查询。
-
提交提示:
submit_prompt
方法使用分词器对提示进行处理,并使用模型生成响应。生成的响应经过解码和记录后返回。
marqo
marqo.py
- 这个类主要用于与 Marqo 向量存储进行交互,管理和搜索与 SQL 查询、DDL 以及文档相关的数据。
- 提供了一套方法,用于添加、检索和删除训练数据,并支持通过相似性搜索获取相关的 SQL、DDL 和文档。
import uuid # 导入用于生成唯一标识符的模块
import marqo # 导入 Marqo 库,用于处理向量存储
import pandas as pd # 导入 pandas 库,用于数据处理
from ..base import VannaBase # 从父模块导入 VannaBase 类
class Marqo_VectorStore(VannaBase):
def __init__(self, config=None):
# 调用父类的构造函数进行初始化
VannaBase.__init__(self, config=config)
# 检查配置中是否包含 marqo_url,否则使用默认值
if config is not None and "marqo_url" in config:
marqo_url = config["marqo_url"]
else:
marqo_url = "http://localhost:8882"
# 检查配置中是否包含 marqo_model,否则使用默认模型名称
if config is not None and "marqo_model" in config:
marqo_model = config["marqo_model"]
else:
marqo_model = "hf/all_datasets_v4_MiniLM-L6"
# 创建 Marqo 客户端
self.mq = marqo.Client(url=marqo_url)
# 创建三个索引:vanna-sql、vanna-ddl、vanna-doc
for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]:
try:
# 尝试创建索引
self.mq.create_index(index, model=marqo_model)
except Exception as e:
# 如果索引已经存在,则捕获异常并打印错误信息
print(e)
print(f"Marqo index {index} already exists")
pass
def generate_embedding(self, data: str, **kwargs) -> list[float]:
# Marqo 不需要生成嵌入
pass
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
# 生成唯一标识符并添加 "-sql" 后缀
id = str(uuid.uuid4()) + "-sql"
# 创建包含问题和 SQL 的字典
question_sql_dict = {
"question": question,
"sql": sql,
"_id": id,
}
# 将文档添加到 "vanna-sql" 索引中
self.mq.index("vanna-sql").add_documents(
[question_sql_dict],
tensor_fields=["question", "sql"],
)
return id
def add_ddl(self, ddl: str, **kwargs) -> str:
# 生成唯一标识符并添加 "-ddl" 后缀
id = str(uuid.uuid4()) + "-ddl"
# 创建包含 DDL 的字典
ddl_dict = {
"ddl": ddl,
"_id": id,
}
# 将文档添加到 "vanna-ddl" 索引中
self.mq.index("vanna-ddl").add_documents(
[ddl_dict],
tensor_fields=["ddl"],
)
return id
def add_documentation(self, documentation: str, **kwargs) -> str:
# 生成唯一标识符并添加 "-doc" 后缀
id = str(uuid.uuid4()) + "-doc"
# 创建包含文档的字典
doc_dict = {
"doc": documentation,
"_id": id,
}
# 将文档添加到 "vanna-doc" 索引中
self.mq.index("vanna-doc").add_documents(
[doc_dict],
tensor_fields=["doc"],
)
return id
def get_training_data(self, **kwargs) -> pd.DataFrame:
data = [] # 初始化一个空列表用于存储数据
# 从 "vanna-doc" 索引中检索文档
for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]:
data.append(
{
"id": hit["_id"],
"training_data_type": "documentation",
"question": "",
"content": hit["doc"],
}
)
# 从 "vanna-ddl" 索引中检索文档
for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]:
data.append(
{
"id": hit["_id"],
"training_data_type": "ddl",
"question": "",
"content": hit["ddl"],
}
)
# 从 "vanna-sql" 索引中检索文档
for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]:
data.append(
{
"id": hit["_id"],
"training_data_type": "sql",
"question": hit["question"],
"content": hit["sql"],
}
)
# 将数据转换为 DataFrame 并返回
df = pd.DataFrame(data)
return df
def remove_training_data(self, id: str, **kwargs) -> bool:
# 根据 ID 后缀确定要删除的索引中的文档
if id.endswith("-sql"):
self.mq.index("vanna-sql").delete_documents(ids=[id])
return True
elif id.endswith("-ddl"):
self.mq.index("vanna-ddl").delete_documents(ids=[id])
return True
elif id.endswith("-doc"):
self.mq.index("vanna-doc").delete_documents(ids=[id])
return True
else:
return False
@staticmethod
def _extract_documents(data) -> list:
# 检查数据中是否包含 'hits' 键且其值是否为列表
if "hits" in data and isinstance(data["hits"], list):
# 如果 'hits' 列表为空,则返回空列表
if len(data["hits"]) == 0:
return []
# 如果 'hits' 中包含 "doc" 键,则返回其值
if "doc" in data["hits"][0]:
return [hit["doc"] for hit in data["hits"]]
# 如果 'hits' 中包含 "ddl" 键,则返回其值
if "ddl" in data["hits"][0]:
return [hit["ddl"] for hit in data["hits"]]
# 否则,返回所有命中的项目
return [
{key: value for key, value in hit.items() if not key.startswith("_")}
for hit in data["hits"]
]
else:
# 如果 'hits' 不存在或不是列表,则返回空列表
return []
def get_similar_question_sql(self, question: str, **kwargs) -> list:
# 从 "vanna-sql" 索引中搜索相似的问题 SQL,并提取文档
return Marqo_VectorStore._extract_documents(
self.mq.index("vanna-sql").search(question)
)
def get_related_ddl(self, question: str, **kwargs) -> list:
# 从 "vanna-ddl" 索引中搜索相关的 DDL,并提取文档
return Marqo_VectorStore._extract_documents(
self.mq.index("vanna-ddl").search(question)
)
def get_related_documentation(self, question: str, **kwargs) -> list:
# 从 "vanna-doc" 索引中搜索相关的文档,并提取文档
return Marqo_VectorStore._extract_documents(
self.mq.index("vanna-doc").search(question)
)
功能和作用
-
类的初始化:
- 初始化时调用父类的构造函数。
- 从配置中读取
marqo_url
和marqo_model
,如果未提供则使用默认值。 - 创建 Marqo 客户端并尝试创建三个索引:
vanna-sql
、vanna-ddl
、vanna-doc
。
-
生成嵌入:
generate_embedding
方法目前没有实现,因为 Marqo 不需要生成嵌入。
-
添加问题和 SQL:
add_question_sql
方法生成一个唯一标识符,并将问题和 SQL 添加到vanna-sql
索引中。
-
添加 DDL:
add_ddl
方法生成一个唯一标识符,并将 DDL 添加到vanna-ddl
索引中。
-
添加文档:
add_documentation
方法生成一个唯一标识符,并将文档添加到vanna-doc
索引中。
-
获取训练数据:
get_training_data
方法从三个索引中检索文档并转换为 pandas DataFrame。
-
删除训练数据:
remove_training_data
方法根据文档 ID 后缀确定要删除的索引中的文档。
-
静态方法提取文档:
_extract_documents
静态方法从搜索结果中提取文档。
-
获取相似问题的 SQL:
get_similar_question_sql
方法从vanna-sql
索引中搜索相似的问题 SQL,并提取文档。
-
获取相关的 DDL:
get_related_ddl
方法从vanna-ddl
索引中搜索相关的 DDL,并提取文档。
- 获取相关文档:
get_related_documentation
方法从vanna-doc
索引中搜索相关的文档,并提取文档。
mistral
mistral.py
- 这个类主要用于与 Mistral AI 服务进行交互,处理聊天消息并生成相应的 SQL 查询。
- 提供了一套方法,用于构建系统消息、用户消息和助手消息,并支持发送聊天请求和处理响应。
from mistralai.client import MistralClient # 从 Mistral 库导入 MistralClient
from mistralai.models.chat_completion import ChatMessage # 从 Mistral 库导入 ChatMessage 模型
from ..base import VannaBase # 从父模块导入 VannaBase 类
class Mistral(VannaBase):
def __init__(self, config=None):
# 如果没有提供配置,抛出 ValueError 异常
if config is None:
raise ValueError(
"For Mistral, config must be provided with an api_key and model"
)
# 如果配置中不包含 api_key,抛出 ValueError 异常
if "api_key" not in config:
raise ValueError("config must contain a Mistral api_key")
# 如果配置中不包含 model,抛出 ValueError 异常
if "model" not in config:
raise ValueError("config must contain a Mistral model")
# 从配置中获取 api_key 和 model
api_key = config["api_key"]
model = config["model"]
# 创建 Mistral 客户端实例
self.client = MistralClient(api_key=api_key)
# 设置模型名称
self.model = model
def system_message(self, message: str) -> any:
# 返回一个系统消息对象
return ChatMessage(role="system", content=message)
def user_message(self, message: str) -> any:
# 返回一个用户消息对象
return ChatMessage(role="user", content=message)
def assistant_message(self, message: str) -> any:
# 返回一个助手消息对象
return ChatMessage(role="assistant", content=message)
def generate_sql(self, question: str, **kwargs) -> str:
# 使用父类的方法生成 SQL 查询
sql = super().generate_sql(question, **kwargs)
# 将 "\_" 替换为 "_"
sql = sql.replace("\\_", "_")
return sql
def submit_prompt(self, prompt, **kwargs) -> str:
# 使用 Mistral 客户端发送聊天请求
chat_response = self.client.chat(
model=self.model,
messages=prompt,
)
# 返回聊天响应中的消息内容
return chat_response.choices[0].message.content
功能和作用
-
类的初始化:
- 初始化时检查配置是否包含
api_key
和model
,如果缺少任意一个则抛出ValueError
异常。 - 创建 Mistral 客户端实例,并设置模型名称。
- 初始化时检查配置是否包含
-
系统消息:
system_message
方法创建并返回一个系统消息对象。
-
用户消息:
user_message
方法创建并返回一个用户消息对象。
-
助手消息:
assistant_message
方法创建并返回一个助手消息对象。
-
生成 SQL 查询:
generate_sql
方法调用父类的方法生成 SQL 查询,然后替换其中的 “_” 为 “_” 并返回最终的 SQL 查询。
-
提交提示:
submit_prompt
方法使用 Mistral 客户端发送聊天请求,并返回聊天响应中的消息内容。
mock
embedding.py
这段代码定义了一个名为 MockEmbedding
的类,继承自 VannaBase
。
MockEmbedding
类包含一个构造函数和一个generate_embedding
方法,后者返回一个固定的浮点数列表。- 该类主要用于测试和开发阶段,提供一个简单的嵌入生成实现。
from typing import List # 从 typing 模块导入 List 类型,用于类型注解
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
class MockEmbedding(VannaBase): # 定义一个名为 MockEmbedding 的类,继承自 VannaBase
def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config
pass # 目前构造函数不做任何操作
def generate_embedding(self, data: str, **kwargs) -> List[float]: # 定义一个名为 generate_embedding 的方法,接受一个字符串参数 data 和其他可选参数,返回一个浮点数列表
return [1.0, 2.0, 3.0, 4.0, 5.0] # 返回一个固定的浮点数列表
功能和作用
-
导入
List
类型:- 从
typing
模块导入List
类型,用于类型注解,指定generate_embedding
方法返回值的类型。
- 从
-
导入
VannaBase
类:- 从上一级目录的
base
模块导入VannaBase
类。VannaBase
类可能是所有具体实现的基类,提供一些基本的功能和接口。
- 从上一级目录的
-
定义
MockEmbedding
类:MockEmbedding
类继承自VannaBase
,用于模拟嵌入生成的功能,通常在测试或开发阶段使用。
-
构造函数
__init__
:- 定义类的构造函数,接受一个可选的配置参数
config
。目前构造函数不做任何实际操作,仅包含一个pass
语句。
- 定义类的构造函数,接受一个可选的配置参数
-
generate_embedding
方法:- 定义一个名为
generate_embedding
的方法,接受一个字符串参数data
和其他可选参数,返回一个浮点数列表。 - 方法的实现简单地返回一个固定的浮点数列表
[1.0, 2.0, 3.0, 4.0, 5.0]
,模拟生成的嵌入向量。
- 定义一个名为
llm.py
这段代码定义了一个名为 MockLLM
的类,继承自 VannaBase
。
MockLLM
类包含一个构造函数和多个消息创建方法,以及一个提交提示的方法,后者返回一个固定的字符串响应。- 该类主要用于测试和开发阶段,提供一个简单的大语言模型模拟实现。
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
class MockLLM(VannaBase): # 定义一个名为 MockLLM 的类,继承自 VannaBase
def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config
pass # 目前构造函数不做任何操作
def system_message(self, message: str) -> any: # 定义一个名为 system_message 的方法,接受一个字符串参数 message,返回一个任意类型的值
return {"role": "system", "content": message} # 返回一个包含角色和内容的字典,角色为 "system"
def user_message(self, message: str) -> any: # 定义一个名为 user_message 的方法,接受一个字符串参数 message,返回一个任意类型的值
return {"role": "user", "content": message} # 返回一个包含角色和内容的字典,角色为 "user"
def assistant_message(self, message: str) -> any: # 定义一个名为 assistant_message 的方法,接受一个字符串参数 message,返回一个任意类型的值
return {"role": "assistant", "content": message} # 返回一个包含角色和内容的字典,角色为 "assistant"
def submit_prompt(self, prompt, **kwargs) -> str: # 定义一个名为 submit_prompt 的方法,接受一个参数 prompt 和其他可选参数,返回一个字符串
return "Mock LLM response" # 返回一个固定的字符串 "Mock LLM response"
代码功能和作用
-
导入
VannaBase
类:- 从上一级目录的
base
模块导入VannaBase
类。VannaBase
类可能是所有具体实现的基类,提供一些基本的功能和接口。
- 从上一级目录的
-
定义
MockLLM
类:MockLLM
类继承自VannaBase
,用于模拟大语言模型(LLM)的功能,通常在测试或开发阶段使用。
-
构造函数
__init__
:- 定义类的构造函数,接受一个可选的配置参数
config
。目前构造函数不做任何实际操作,仅包含一个pass
语句。
- 定义类的构造函数,接受一个可选的配置参数
-
system_message
方法:- 定义一个名为
system_message
的方法,接受一个字符串参数message
,返回一个包含角色和内容的字典,角色为 “system”。 - 该方法用于创建系统消息的结构。
- 定义一个名为
-
user_message
方法:- 定义一个名为
user_message
的方法,接受一个字符串参数message
,返回一个包含角色和内容的字典,角色为 “user”。 - 该方法用于创建用户消息的结构。
- 定义一个名为
-
assistant_message
方法:- 定义一个名为
assistant_message
的方法,接受一个字符串参数message
,返回一个包含角色和内容的字典,角色为 “assistant”。 - 该方法用于创建助手消息的结构。
- 定义一个名为
-
submit_prompt
方法:- 定义一个名为
submit_prompt
的方法,接受一个参数prompt
和其他可选参数,返回一个字符串。 - 该方法返回一个固定的字符串 “Mock LLM response”,模拟大语言模型的响应。
- 定义一个名为
vectordb.py
MockVectorDB
类模拟了一个简单的向量数据库,实现了一些基本的数据库操作和数据处理功能。该类主要用于测试和开发阶段,提供一个简单的实现,而无需依赖实际的数据库操作。
- 数据处理:通过
pandas
DataFrame 返回训练数据,便于数据的处理和操作。 - 接口实现:实现了一些基本的接口方法,如获取相关的 DDL、文档和类似问题及 SQL,这些方法在测试和开发阶段可以返回空值或固定值。
import pandas as pd # 导入 pandas 库,用于处理数据
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
class MockVectorDB(VannaBase): # 定义一个名为 MockVectorDB 的类,继承自 VannaBase
def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config
pass # 目前构造函数不做任何操作
def _get_id(self, value: str, **kwargs) -> str: # 定义一个私有方法 _get_id,接受一个字符串参数 value 和其他可选参数
# 将值进行哈希处理并返回 ID
return str(hash(value)) # 返回值的哈希值转换成字符串形式
def add_ddl(self, ddl: str, **kwargs) -> str: # 定义一个方法 add_ddl,接受一个字符串参数 ddl 和其他可选参数
return self._get_id(ddl) # 调用 _get_id 方法并返回其结果
def add_documentation(self, doc: str, **kwargs) -> str: # 定义一个方法 add_documentation,接受一个字符串参数 doc 和其他可选参数
return self._get_id(doc) # 调用 _get_id 方法并返回其结果
def add_question_sql(self, question: str, sql: str, **kwargs) -> str: # 定义一个方法 add_question_sql,接受一个字符串参数 question 和 sql 以及其他可选参数
return self._get_id(question) # 调用 _get_id 方法并返回其结果
def get_related_ddl(self, question: str, **kwargs) -> list: # 定义一个方法 get_related_ddl,接受一个字符串参数 question 和其他可选参数
return [] # 返回一个空列表
def get_related_documentation(self, question: str, **kwargs) -> list: # 定义一个方法 get_related_documentation,接受一个字符串参数 question 和其他可选参数
return [] # 返回一个空列表
def get_similar_question_sql(self, question: str, **kwargs) -> list: # 定义一个方法 get_similar_question_sql,接受一个字符串参数 question 和其他可选参数
return [] # 返回一个空列表
def get_training_data(self, **kwargs) -> pd.DataFrame: # 定义一个方法 get_training_data,接受其他可选参数,返回一个 pandas DataFrame
# 返回一个包含训练数据的 DataFrame
return pd.DataFrame({'id': {0: '19546-ddl', # 训练数据 ID
1: '91597-sql',
2: '133976-sql',
3: '59851-doc',
4: '73046-sql'},
'training_data_type': {0: 'ddl', # 训练数据类型
1: 'sql',
2: 'sql',
3: 'documentation',
4: 'sql'},
'question': {0: None, # 问题
1: 'What are the top selling genres?',
2: 'What are the low 7 artists by sales?',
3: None,
4: 'What is the total sales for each customer?'},
'content': {0: 'CREATE TABLE [Invoice]\n(\n [InvoiceId] INTEGER NOT NULL,\n [CustomerId] INTEGER NOT NULL,\n [InvoiceDate] DATETIME NOT NULL,\n [BillingAddress] NVARCHAR(70),\n [BillingCity] NVARCHAR(40),\n [BillingState] NVARCHAR(40),\n [BillingCountry] NVARCHAR(40),\n [BillingPostalCode] NVARCHAR(10),\n [Total] NUMERIC(10,2) NOT NULL,\n CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]),\n FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)', # 内容
1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})
def remove_training_data(id: str, **kwargs) -> bool: # 定义一个方法 remove_training_data,接受一个字符串参数 id 和其他可选参数
return True # 返回 True,表示删除成功
代码功能和作用
-
导入模块:
pandas
:用于数据处理和操作。VannaBase
:从上一级目录的base
模块导入VannaBase
类,作为基类。
-
定义
MockVectorDB
类:- 继承自
VannaBase
,模拟一个向量数据库的基本操作。
- 继承自
-
构造函数
__init__
:- 定义类的构造函数,目前不做任何实际操作,仅包含一个
pass
语句。
- 定义类的构造函数,目前不做任何实际操作,仅包含一个
-
_get_id
方法:- 私有方法
_get_id
,接受一个字符串参数value
,返回该字符串的哈希值作为 ID。
- 私有方法
-
add_ddl
方法:- 接受一个 DDL 语句字符串,调用
_get_id
方法返回其哈希值作为 ID。
- 接受一个 DDL 语句字符串,调用
-
add_documentation
方法:- 接受一个文档字符串,调用
_get_id
方法返回其哈希值作为 ID。
- 接受一个文档字符串,调用
-
add_question_sql
方法:- 接受一个问题和对应的 SQL 语句,调用
_get_id
方法返回问题的哈希值作为 ID。
- 接受一个问题和对应的 SQL 语句,调用
-
get_related_ddl
方法:- 接受一个问题字符串,返回一个空列表,表示没有相关的 DDL。
-
get_related_documentation
方法:- 接受一个问题字符串,返回一个空列表,表示没有相关的文档。
-
get_similar_question_sql
方法:- 接受一个问题字符串,返回一个空列表,表示没有类似的问题和 SQL。
-
get_training_data
方法:- 返回一个包含训练数据的 pandas DataFrame,其中包含 ID、训练数据类型、问题和内容。
-
remove_training_data
方法:- 接受一个数据 ID,返回
True
,表示成功删除训练数据。
- 接受一个数据 ID,返回
openai
openai_chat.py
代码用途
OpenAI_Chat
类实现了与 OpenAI 模型的基本接口,提供了初始化、消息处理和提示提交等功能。该类用于与 OpenAI 模型进行通信,发送提示并接收响应,同时处理和返回响应中的文本内容。
- 与 OpenAI 模型的通信:
OpenAI_Chat
类实现了与 OpenAI 模型的接口,通过 HTTP 请求与 OpenAI 模型进行通信,发送提示并接收响应。 - 消息处理:提供系统消息、用户消息和助手消息的方法,用于生成与 OpenAI 模型通信的消息格式。
- 提示提交和响应处理:检查和验证提示内容,计算 token 数量,并调用 OpenAI API 生成响应,同时处理和返回响应中的文本内容。
import os # 导入 os 模块,用于与操作系统交互
from openai import OpenAI # 从 openai 模块导入 OpenAI 类
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
class OpenAI_Chat(VannaBase): # 定义一个名为 OpenAI_Chat 的类,继承自 VannaBase
def __init__(self, client=None, config=None): # 定义类的构造函数,接受可选参数 client 和 config
VannaBase.__init__(self, config=config) # 调用父类 VannaBase 的构造函数
# 设置默认参数 - 可以通过 config 覆盖
self.temperature = 0.7
self.max_tokens = 500
if "temperature" in config:
self.temperature = config["temperature"] # 如果 config 中包含 "temperature",则覆盖默认温度值
if "max_tokens" in config:
self.max_tokens = config["max_tokens"] # 如果 config 中包含 "max_tokens",则覆盖默认最大 token 数
if "api_type" in config:
raise Exception(
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
) # 如果 config 中包含 "api_type",抛出异常
if "api_base" in config:
raise Exception(
"Passing api_base is now deprecated. Please pass an OpenAI client instead."
) # 如果 config 中包含 "api_base",抛出异常
if "api_version" in config:
raise Exception(
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
) # 如果 config 中包含 "api_version",抛出异常
if client is not None:
self.client = client # 如果提供了 client 参数,则将其设置为实例属性
return
if config is None and client is None:
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) # 如果没有提供 config 和 client,从环境变量中获取 API 密钥并创建 OpenAI 客户端
return
if "api_key" in config:
self.client = OpenAI(api_key=config["api_key"]) # 如果 config 中包含 "api_key",使用该密钥创建 OpenAI 客户端
def system_message(self, message: str) -> any: # 定义一个方法 system_message,接受一个字符串参数 message
return {"role": "system", "content": message} # 返回一个包含角色和内容的字典
def user_message(self, message: str) -> any: # 定义一个方法 user_message,接受一个字符串参数 message
return {"role": "user", "content": message} # 返回一个包含角色和内容的字典
def assistant_message(self, message: str) -> any: # 定义一个方法 assistant_message,接受一个字符串参数 message
return {"role": "assistant", "content": message} # 返回一个包含角色和内容的字典
def submit_prompt(self, prompt, **kwargs) -> str: # 定义一个方法 submit_prompt,接受一个参数 prompt 和其他可选参数
if prompt is None:
raise Exception("Prompt is None") # 如果 prompt 为 None,抛出异常
if len(prompt) == 0:
raise Exception("Prompt is empty") # 如果 prompt 为空,抛出异常
# 计算消息日志中的 token 数量
# 使用 4 作为每个 token 的近似字符数
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4
if kwargs.get("model", None) is not None:
model = kwargs.get("model", None) # 如果 kwargs 中提供了 model 参数,则使用该模型
print(
f"Using model {model} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
engine = kwargs.get("engine", None) # 如果 kwargs 中提供了 engine 参数,则使用该引擎
print(
f"Using model {engine} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=engine,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
engine=self.config["engine"],
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
elif self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
else:
if num_tokens > 3500:
model = "gpt-3.5-turbo-16k" # 如果 token 数量超过 3500,使用 gpt-3.5-turbo-16k 模型
else:
model = "gpt-3.5-turbo" # 否则使用 gpt-3.5-turbo 模型
print(f"Using model {model} for {num_tokens} tokens (approx)")
response = self.client.chat.completions.create(
model=model,
messages=prompt,
max_tokens=self.max_tokens,
stop=None,
temperature=self.temperature,
)
# 查找包含文本的第一个响应(有些响应可能没有文本)
for choice in response.choices:
if "text" in choice:
return choice.text
# 如果没有找到包含文本的响应,返回第一个响应的内容(可能为空)
return response.choices[0].message.content
功能和作用
-
导入模块:
os
:用于与操作系统交互,特别是获取环境变量。OpenAI
:用于与 OpenAI API 进行交互。VannaBase
和DependencyError
:用于继承基类和处理依赖错误。
-
定义
OpenAI_Chat
类:- 继承自
VannaBase
,实现了与 OpenAI 模型的接口。
- 继承自
-
构造函数
__init__
:- 调用父类的构造函数并初始化一些默认参数。
- 根据配置覆盖默认参数。
- 检查并设置 OpenAI 客户端。
-
消息方法:
system_message
、user_message
和assistant_message
方法分别返回带有角色和内容的字典,用于与 OpenAI 模型的通信。
-
submit_prompt
方法:- 检查和验证提示内容。
- 计算消息日志中的 token 数量。
- 根据不同的配置和参数,调用 OpenAI API 生成响应。
- 查找并返回包含文本的第一个响应。
openai_embeddings.py
这个类的主要功能是初始化 OpenAI 客户端并根据提供的配置生成嵌入向量。
from openai import OpenAI # 从 openai 库导入 OpenAI 类,用于与 OpenAI API 交互
from ..base import VannaBase # 从本地模块导入 VannaBase 类,这是一个基础类
# 定义一个新的类 OpenAI_Embeddings,它继承自 VannaBase
class OpenAI_Embeddings(VannaBase):
def __init__(self, client=None, config=None):
# 调用父类 VannaBase 的构造函数,并传递配置参数
VannaBase.__init__(self, config=config)
# 如果提供了 client 参数,则使用这个参数初始化 self.client
if client is not None:
self.client = client
return # 结束构造函数的执行
# 如果 self.client 已经存在(通过父类初始化),则结束构造函数
if self.client is not None:
return
# 如果没有提供 client 参数且 self.client 尚未初始化,则创建一个新的 OpenAI 客户端
self.client = OpenAI()
# 如果没有提供配置参数,则结束构造函数
if config is None:
return
# 根据配置参数设置 OpenAI 客户端的不同属性
if "api_type" in config:
self.client.api_type = config["api_type"] # 设置 API 类型
if "api_base" in config:
self.client.api_base = config["api_base"] # 设置 API 基础 URL
if "api_version" in config:
self.client.api_version = config["api_version"] # 设置 API 版本
if "api_key" in config:
self.client.api_key = config["api_key"] # 设置 API 密钥
# 定义生成嵌入的方法,接受输入数据 data 和其他可选参数 kwargs
def generate_embedding(self, data: str, **kwargs) -> list[float]:
# 如果配置参数中包含引擎配置,则使用该引擎生成嵌入
if self.config is not None and "engine" in self.config:
embedding = self.client.embeddings.create(
engine=self.config["engine"], # 使用指定的引擎
input=data, # 输入数据
)
else:
# 否则使用默认的模型生成嵌入
embedding = self.client.embeddings.create(
model="text-embedding-ada-002", # 使用默认模型
input=data, # 输入数据
)
# 返回生成的嵌入向量
return embedding.get("data")[0]["embedding"]
opensearch
opensearch_vector.py
OpenSearch_VectorStore 类用于与 OpenSearch 进行交互,提供存储和检索文档、DDL(数据定义语言)和问题-SQL 对的方法。
import base64 # 导入 base64 模块,用于处理 base64 编码
import uuid # 导入 uuid 模块,用于生成唯一标识符
from typing import List # 导入 List 类型提示
import pandas as pd # 导入 pandas 模块,用于数据处理
from opensearchpy import OpenSearch # 导入 OpenSearch 模块,用于与 OpenSearch 服务交互
from ..base import VannaBase # 从本地模块导入 VannaBase 类,这是一个基础类
# 定义 OpenSearch_VectorStore 类,继承自 VannaBase
class OpenSearch_VectorStore(VannaBase):
def __init__(self, config=None):
# 调用父类 VannaBase 的构造函数,并传递配置参数
VannaBase.__init__(self, config=config)
# 初始化索引名称
document_index = "vanna_document_index"
ddl_index = "vanna_ddl_index"
question_sql_index = "vanna_questions_sql_index"
# 从配置中获取自定义索引名称
if config is not None and "es_document_index" in config:
document_index = config["es_document_index"]
if config is not None and "es_ddl_index" in config:
ddl_index = config["es_ddl_index"]
if config is not None and "es_question_sql_index" in config:
question_sql_index = config["es_question_sql_index"]
# 将索引名称保存为类的属性
self.document_index = document_index
self.ddl_index = ddl_index
self.question_sql_index = question_sql_index
print("OpenSearch_VectorStore initialized with document_index: ",
document_index, " ddl_index: ", ddl_index, " question_sql_index: ",
question_sql_index)
# 定义默认索引设置
document_index_settings = {
"settings": {
"index": {
"number_of_shards": 6,
"number_of_replicas": 2
}
},
"mappings": {
"properties": {
"question": {"type": "text"},
"doc": {"type": "text"}
}
}
}
ddl_index_settings = {
"settings": {
"index": {
"number_of_shards": 6,
"number_of_replicas": 2
}
},
"mappings": {
"properties": {
"ddl": {"type": "text"},
"doc": {"type": "text"}
}
}
}
question_sql_index_settings = {
"settings": {
"index": {
"number_of_shards": 6,
"number_of_replicas": 2
}
},
"mappings": {
"properties": {
"question": {"type": "text"},
"sql": {"type": "text"}
}
}
}
# 从配置中获取自定义索引设置
if config is not None and "es_document_index_settings" in config:
document_index_settings = config["es_document_index_settings"]
if config is not None and "es_ddl_index_settings" in config:
ddl_index_settings = config["es_ddl_index_settings"]
if config is not None and "es_question_sql_index_settings" in config:
question_sql_index_settings = config["es_question_sql_index_settings"]
# 将索引设置保存为类的属性
self.document_index_settings = document_index_settings
self.ddl_index_settings = ddl_index_settings
self.question_sql_index_settings = question_sql_index_settings
# 初始化 OpenSearch 客户端
es_urls = None
if config is not None and "es_urls" in config:
es_urls = config["es_urls"]
# 获取主机和端口配置
host = config["es_host"] if config and "es_host" in config else "localhost"
port = config["es_port"] if config and "es_port" in config else 9200
ssl = config["es_ssl"] if config and "es_ssl" in config else False
verify_certs = config["es_verify_certs"] if config and "es_verify_certs" in config else False
# 获取认证配置
auth = (config["es_user"], config["es_password"]) if config and "es_user" in config else None
# 基于 base64 的认证
headers = None
if config and "es_encoded_base64" in config and "es_user" in config and "es_password" in config:
if config["es_encoded_base64"]:
encoded_credentials = base64.b64encode(
(config["es_user"] + ":" + config["es_password"]).encode("utf-8")
).decode("utf-8")
headers = {'Authorization': 'Basic ' + encoded_credentials}
auth = None
# 自定义 headers
if config and "es_headers" in config:
headers = config["es_headers"]
# 获取超时和重试配置
timeout = config["es_timeout"] if config and "es_timeout" in config else 60
max_retries = config["es_max_retries"] if config and "es_max_retries" in config else 10
es_http_compress = config["es_http_compress"] if config and "es_http_compress" in config else False
print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
" host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
# 初始化 OpenSearch 客户端
if es_urls is not None:
self.client = OpenSearch(
hosts=[es_urls],
http_compress=es_http_compress,
use_ssl=ssl,
verify_certs=verify_certs,
timeout=timeout,
max_retries=max_retries,
retry_on_timeout=True,
http_auth=auth,
headers=headers
)
else:
self.client = OpenSearch(
hosts=[{'host': host, 'port': port}],
http_compress=es_http_compress,
use_ssl=ssl,
verify_certs=verify_certs,
timeout=timeout,
max_retries=max_retries,
retry_on_timeout=True,
http_auth=auth,
headers=headers
)
print("OpenSearch_VectorStore initialized with client over ")
# 执行一个简单的查询来检查连接
try:
print('Connected to OpenSearch cluster:')
info = self.client.info()
print('OpenSearch cluster info:', info)
except Exception as e:
print('Error connecting to OpenSearch cluster:', e)
# 如果索引不存在,则创建索引
self.create_index_if_not_exists(self.document_index, self.document_index_settings)
self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings)
self.create_index_if_not_exists(self.question_sql_index, self.question_sql_index_settings)
# 创建索引的方法
def create_index(self):
for index in [self.document_index, self.ddl_index, self.question_sql_index]:
try:
self.client.indices.create(index)
except Exception as e:
print("Error creating index: ", e)
print(f"opensearch index {index} already exists")
pass
# 如果索引不存在,则创建索引的方法
def create_index_if_not_exists(self, index_name: str, index_settings: dict) -> bool:
try:
if not self.client.indices.exists(index_name):
print(f"Index {index_name} does not exist. Creating...")
self.client.indices.create(index=index_name, body=index_settings)
return True
else:
print(f"Index {index_name} already exists.")
return False
except Exception as e:
print(f"Error creating index: {index_name} ", e)
return False
# 添加 DDL 文档的方法
def add_ddl(self, ddl: str, **kwargs) -> str:
id = str(uuid.uuid4()) + "-ddl" # 生成唯一标识符
ddl_dict = {"ddl": ddl}
response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id, **kwargs)
return response['_id']
# 添加文档的方法
def add_documentation(self, doc: str, **kwargs) -> str:
id = str(uuid.uuid4()) + "-doc" # 生成唯一标识符
doc_dict = {"doc": doc}
response = self.client.index(index=self.document_index, id=id, body=doc_dict, **kwargs)
return response['_id']
# 添加问题和 SQL 的方法
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
id = str(uuid.uuid4()) + "-sql" # 生成唯一标识符
question_sql_dict = {"question": question, "sql": sql}
response = self.client.index(index=self.question_sql_index, body=question_sql_dict, id=id, **kwargs)
return response['_id']
# 获取相关 DDL 文档的方法
def get_related_ddl(self, question: str, **kwargs) -> List[str]:
query = {"query": {"match": {"ddl": question}}}
print(query)
response = self.client.search(index=self.ddl_index, body=query, **kwargs)
return [hit['_source']['ddl'] for hit in response['hits']['hits']]
# 获取相关文档的方法
def get_related_documentation(self, question: str, **kwargs) -> List[str]:
query = {"query": {"match": {"doc": question}}}
print(query)
response = self.client.search(index=self.document_index, body=query, **kwargs)
return [hit['_source']['doc'] for hit in response['hits']['hits']]
# 获取相似问题和 SQL 的方法
def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
query = {"query": {"match": {"question": question}}}
print(query)
response = self.client.search(index=self.question_sql_index, body=query, **kwargs)
return [(hit['_source']['question'], hit['_source']['sql']) for hit in response['hits']['hits']]
# 获取训练数据的方法
def get_training_data(self, **kwargs) -> pd.DataFrame:
data = []
# 从文档索引中获取数据
response = self.client.search(index=self.document_index, body={"query": {"match_all": {}}}, size=1000)
for hit in response['hits']['hits']:
data.append({"id": hit["_id"], "training_data_type": "documentation", "question": "", "content": hit["_source"]['doc']})
# 从问题和 SQL 索引中获取数据
response = self.client.search(index=self.question_sql_index, body={"query": {"match_all": {}}}, size=1000)
for hit in response['hits']['hits']:
data.append({"id": hit["_id"], "training_data_type": "sql", "question": hit.get("_source", {}).get("question", ""), "content": hit.get("_source", {}).get("sql", "")})
# 从 DDL 索引中获取数据
response = self.client.search(index=self.ddl_index, body={"query": {"match_all": {}}}, size=1000)
for hit in response['hits']['hits']:
data.append({"id": hit["_id"], "training_data_type": "ddl", "question": "", "content": hit["_source"]['ddl']})
# 返回包含所有数据的 pandas DataFrame
return pd.DataFrame(data)
# 删除训练数据的方法
def remove_training_data(self, id: str, **kwargs) -> bool:
try:
if id.endswith("-sql"):
self.client.delete(index=self.question_sql_index, id=id)
return True
elif id.endswith("-ddl"):
self.client.delete(index=self.ddl_index, id=id, **kwargs)
return True
elif id.endswith("-doc"):
self.client.delete(index=self.document_index, id=id, **kwargs)
return True
else:
return False
except Exception as e:
print("Error deleting training data: ", e)
return False
# 生成嵌入的方法(空方法)
def generate_embedding(self, data: str, **kwargs) -> list[float]:
pass # OpenSearch 不需要生成嵌入
# 示例初始化调用
# OpenSearch_VectorStore.__init__(self, config={'es_urls': "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True})
# OpenSearch_VectorStore.__init__(self, config={'es_host': "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True})
功能和作用
1.初始化:
-
配置索引名称和设置。
-
初始化 OpenSearch 客户端,支持多种配置选项(如认证、超时、重试等)。
-
创建必要的索引。
2.索引管理: -
create_index_if_not_exists:检查索引是否存在,不存在则创建。
-
create_index:创建指定的索引。
3.文档操作: -
add_ddl、add_documentation、add_question_sql:添加 DDL 文档、普通文档和问题-SQL 对。
-
get_related_ddl、get_related_documentation、get_similar_question_sql:检索相关的 DDL 文档、普通文档和问题-SQL 对。
-
get_training_data:获取所有训练数据,返回 pandas DataFrame。
-
remove_training_data:删除指定的训练数据。