前言
vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析
一、vanna源码分析
exceptions
定义了一些自定义异常,用于在 Vanna 项目中处理特定错误情况。
class ImproperlyConfigured(Exception):
"""配置错误时引发的异常。"""
pass
class DependencyError(Exception):
"""缺少依赖项时引发的异常。"""
pass
class ConnectionError(Exception):
"""连接错误时引发的异常。"""
pass
class OTPCodeError(Exception):
"""OTP无效或无法发送时引发的异常。"""
pass
class SQLRemoveError(Exception):
"""无法删除SQL时引发的异常。"""
pass
class ExecutionError(Exception):
"""代码执行失败时引发的异常。"""
pass
class ValidationError(Exception):
"""验证错误时引发的异常。"""
pass
class APIError(Exception):
"""API相关错误时引发的异常。"""
pass
flask
init.py
初始化和配置一个 Flask 应用,用于提供 Web 服务和 API 接口,以便与 Vanna 实例进行交互。
主要作用
缓存管理:定义一个缓存接口 Cache 及其内存实现 MemoryCache,用于在应用中缓存数据。
认证管理:提供认证装饰器 requires_auth,确保只有经过认证的用户才能访问某些路由。
路由定义:定义了多个 API 路由,用于处理不同的 HTTP 请求,例如生成 SQL、获取配置、管理训练数据等。
WebSocket 支持:使用 flask_sock 添加 WebSocket 支持,用于实时日志传输。
import json # 导入处理JSON数据的库
import logging # 导入日志记录库
import sys # 导入系统模块,用于与Python解释器交互
import uuid # 导入UUID模块,用于生成唯一标识符
from abc import ABC, abstractmethod # 导入抽象基类和抽象方法
from functools import wraps # 导入装饰器工具
import flask # 导入Flask库,用于创建Web应用
import requests # 导入请求库,用于发送HTTP请求
from flask import Flask, Response, jsonify, request # 从Flask库中导入相关类和方法
from flask_sock import Sock # 导入用于处理WebSocket的库
from .assets import css_content, html_content, js_content # 导入静态资源内容
from .auth import AuthInterface, NoAuth # 导入认证接口和无认证类
class Cache(ABC):
"""
定义一个用于在Flask应用中存储数据的缓存接口。
"""
@abstractmethod
def generate_id(self, *args, **kwargs):
"""
生成一个唯一的缓存ID。
"""
pass
@abstractmethod
def get(self, id, field):
"""
从缓存中获取一个值。
"""
pass
@abstractmethod
def get_all(self, field_list) -> list:
"""
从缓存中获取所有值。
"""
pass
@abstractmethod
def set(self, id, field, value):
"""
在缓存中设置一个值。
"""
pass
@abstractmethod
def delete(self, id):
"""
从缓存中删除一个值。
"""
pass
class MemoryCache(Cache):
def __init__(self):
self.cache = {} # 初始化一个空字典作为缓存
def generate_id(self, *args, **kwargs):
return str(uuid.uuid4()) # 生成一个UUID作为ID
def set(self, id, field, value):
if id not in self.cache:
self.cache[id] = {} # 如果ID不存在,则在缓存中创建一个新条目
self.cache[id][field] = value # 设置字段值
def get(self, id, field):
if id not in self.cache:
return None # 如果ID不存在,返回None
if field not in self.cache[id]:
return None # 如果字段不存在,返回None
return self.cache[id][field] # 返回字段值
def get_all(self, field_list) -> list:
return [
{"id": id, **{field: self.get(id=id, field=field) for field in field_list}}
for id in self.cache
] # 返回所有缓存条目
def delete(self, id):
if id in self.cache:
del self.cache[id] # 删除缓存条目
class VannaFlaskApp:
flask_app = None # 初始化flask_app为None
def requires_cache(self, required_fields, optional_fields=[]):
def decorator(f):
@wraps(f)
def decorated(*args, **kwargs):
id = request.args.get("id") # 从请求参数中获取ID
if id is None:
id = request.json.get("id") # 如果ID为空,从JSON数据中获取ID
if id is None:
return jsonify({"type": "error", "error": "No id provided"}) # 返回错误响应
for field in required_fields:
if self.cache.get(id=id, field=field) is None:
return jsonify({"type": "error", "error": f"No {field} found"}) # 如果必需字段不存在,返回错误响应
field_values = {
field: self.cache.get(id=id, field=field) for field in required_fields
} # 获取必需字段值
for field in optional_fields:
field_values[field] = self.cache.get(id=id
auth.py
from abc import ABC, abstractmethod # 导入抽象基类和抽象方法
import flask # 导入Flask库,用于创建Web应用
class AuthInterface(ABC):
@abstractmethod
def get_user(self, flask_request) -> any:
pass # 获取用户信息
@abstractmethod
def is_logged_in(self, user: any) -> bool:
pass # 检查用户是否已登录
@abstractmethod
def override_config_for_user(self, user: any, config: dict) -> dict:
pass # 根据用户覆盖配置
@abstractmethod
def login_form(self) -> str:
pass # 返回登录表单的HTML
@abstractmethod
def login_handler(self, flask_request) -> str:
pass # 处理登录请求
@abstractmethod
def callback_handler(self, flask_request) -> str:
pass # 处理认证回调
@abstractmethod
def logout_handler(self, flask_request) -> str:
pass # 处理注销请求
class NoAuth(AuthInterface):
def get_user(self, flask_request) -> any:
return {} # 返回空的用户信息
def is_logged_in(self, user: any) -> bool:
return True # 始终返回True,表示用户已登录
def override_config_for_user(self, user: any, config: dict) -> dict:
return config # 不修改配置,直接返回
def login_form(self) -> str:
return '' # 返回空字符串,表示不需要登录表单
def login_handler(self, flask_request) -> str:
return 'No login required' # 返回提示信息,表示不需要登录
def callback_handler(self, flask_request) -> str:
return 'No login required' # 返回提示信息,表示不需要认证回调
def logout_handler(self, flask_request) -> str:
return 'No login required' # 返回提示信息,表示不需要注销
ollama
ollama.py
- 与 Ollama 模型的通信:
Ollama
类实现了与 Ollama 模型的接口,通过 HTTP 请求与 Ollama 模型进行通信,发送提示并接收响应。 - SQL 语句提取:提供从 LLM 响应中提取 SQL 语句的方法,便于后续处理和执行。
- 日志记录:记录请求参数、提示内容和响应,便于调试和分析。
Ollama
类实现了与 Ollama 模型的基本接口,提供了初始化、消息处理、SQL 提取和提示提交等功能。该类用于与 Ollama 模型进行通信,发送提示并接收响应,同时记录相关日志以便调试和分析。
import json # 导入 json 模块,用于处理 JSON 数据
import re # 导入 re 模块,用于正则表达式操作
from httpx import Timeout # 从 httpx 库导入 Timeout 类,用于设置请求超时
from ..base import VannaBase # 从上一级目录的 base 模块导入 VannaBase 类
from ..exceptions import DependencyError # 从上一级目录的 exceptions 模块导入 DependencyError 类
class Ollama(VannaBase): # 定义一个名为 Ollama 的类,继承自 VannaBase
def __init__(self, config=None): # 定义类的构造函数,接受一个可选的配置参数 config
try:
ollama = __import__("ollama") # 尝试导入 ollama 模块
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method, run command:"
" \npip install ollama" # 如果导入失败,抛出 DependencyError 异常,并提示用户安装所需依赖
)
if not config:
raise ValueError("config must contain at least Ollama model") # 如果没有提供 config,抛出 ValueError 异常
if 'model' not in config.keys():
raise ValueError("config must contain at least Ollama model") # 如果 config 中不包含 'model' 键,抛出 ValueError 异常
self.host = config.get("ollama_host", "http://localhost:11434") # 获取 Ollama 主机地址,默认为 "http://localhost:11434"
self.model = config["model"] # 获取 Ollama 模型名称
if ":" not in self.model:
self.model += ":latest" # 如果模型名称中不包含 ":",则添加 ":latest"
self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0)) # 创建 Ollama 客户端实例,设置超时时间为 240 秒
self.keep_alive = config.get('keep_alive', None) # 获取 keep_alive 配置,默认为 None
self.ollama_options = config.get('options', {}) # 获取 Ollama 选项配置,默认为空字典
self.num_ctx = self.ollama_options.get('num_ctx', 2048) # 获取 num_ctx 配置,默认为 2048
self.__pull_model_if_ne(self.ollama_client, self.model) # 调用私有方法 __pull_model_if_ne,检查并拉取模型
@staticmethod
def __pull_model_if_ne(ollama_client, model): # 定义一个静态方法 __pull_model_if_ne,接受 ollama_client 和 model 参数
model_response = ollama_client.list() # 获取模型列表
model_lists = [model_element['model'] for model_element in
model_response.get('models', [])] # 提取模型名称列表
if model not in model_lists:
ollama_client.pull(model) # 如果模型不在列表中,拉取模型
def system_message(self, message: str) -> any: # 定义一个方法 system_message,接受一个字符串参数 message
return {"role": "system", "content": message} # 返回一个包含角色和内容的字典
def user_message(self, message: str) -> any: # 定义一个方法 user_message,接受一个字符串参数 message
return {"role": "user", "content": message} # 返回一个包含角色和内容的字典
def assistant_message(self, message: str) -> any: # 定义一个方法 assistant_message,接受一个字符串参数 message
return {"role": "assistant", "content": message} # 返回一个包含角色和内容的字典
def extract_sql(self, llm_response): # 定义一个方法 extract_sql,接受一个字符串参数 llm_response
"""
提取第一个 SQL 语句,忽略大小写,
匹配到第一个分号、三个反引号或字符串末尾,
并移除提取字符串中的三个反引号。
参数:
- llm_response (str): 要搜索的字符串。
返回:
- str: 提取的第一个 SQL 语句,移除三个反引号,如果没有匹配项,则返回空字符串。
"""
# 移除 ollama 生成的额外字符
llm_response = llm_response.replace("\\_", "_")
llm_response = llm_response.replace("\\", "")
# 正则表达式查找 ```sql' 并捕获直到 '```'
sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
# 正则表达式查找 'select, with (忽略大小写) 并捕获直到 ';', [ (在 mistral 中发生) 或字符串末尾
select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)',
llm_response,
re.IGNORECASE | re.DOTALL)
if sql:
self.log(
f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
return sql.group(1).replace("```", "")
elif select_with:
self.log(
f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}")
return select_with.group(0)
else:
return llm_response
def submit_prompt(self, prompt, **kwargs) -> str: # 定义一个方法 submit_prompt,接受一个参数 prompt 和其他可选参数
self.log(
f"Ollama parameters:\n"
f"model={self.model},\n"
f"options={self.ollama_options},\n"
f"keep_alive={self.keep_alive}") # 记录 Ollama 参数
self.log(f"Prompt Content:\n{json.dumps(prompt)}") # 记录提示内容
response_dict = self.ollama_client.chat(model=self.model,
messages=prompt,
stream=False,
options=self.ollama_options,
keep_alive=self.keep_alive) # 调用 Ollama 客户端的 chat 方法,获取响应
self.log(f"Ollama Response:\n{str(response_dict)}") # 记录 Ollama 响应
return response_dict['message']['content'] # 返回响应中的消息内容
代码功能和作用
-
导入模块:
json
和re
:用于处理 JSON 数据和正则表达式操作。Timeout
:用于设置请求超时。VannaBase
和DependencyError
:用于继承基类和处理依赖错误。
-
定义
Ollama
类:- 继承自
VannaBase
,实现了与 Ollama 模型的接口。
- 继承自
-
构造函数
__init__
:- 尝试导入
ollama
模块,如果失败则抛出DependencyError
。 - 验证
config
参数是否存在并包含必要的键。 - 初始化 Ollama 客户端和模型参数。
- 调用私有方法
__pull_model_if_ne
检查并拉取模型。
- 尝试导入
-
静态方法
__pull_model_if_ne
:- 检查指定模型是否存在,如果不存在则拉取模型。
-
消息方法:
system_message
、user_message
和assistant_message
方法分别返回带有角色和内容的字典,用于与 Ollama 模型的通信。
-
extract_sql
方法:- 从 LLM 响应中提取 SQL 语句,移除生成的额外字符,并使用正则表达式匹配和提取 SQL 语句。
-
submit_prompt
方法:- 记录请求参数和提示内容,调用 Ollama 客户端的
chat
方法获取响应,并返回响应中的消息内容。
- 记录请求参数和提示内容,调用 Ollama 客户端的