vanna学习日志(三)


前言

vanna可实现自然语言转SQL,尝试本地部署vanna对接数据库,将自然语言转成标准的SQL对数据库进行查询。本文先对vanna源码进行分析

一、vanna源码分析

exceptions

定义了一些自定义异常,用于在 Vanna 项目中处理特定错误情况。
class ImproperlyConfigured(Exception):
    """配置错误时引发的异常。"""
    pass


class DependencyError(Exception):
    """缺少依赖项时引发的异常。"""
    pass


class ConnectionError(Exception):
    """连接错误时引发的异常。"""
    pass


class OTPCodeError(Exception):
    """OTP无效或无法发送时引发的异常。"""
    pass


class SQLRemoveError(Exception):
    """无法删除SQL时引发的异常。"""
    pass


class ExecutionError(Exception):
    """代码执行失败时引发的异常。"""
    pass


class ValidationError(Exception):
    """验证错误时引发的异常。"""
    pass


class APIError(Exception):
    """API相关错误时引发的异常。"""
    pass

flask

init.py

初始化和配置一个 Flask 应用,用于提供 Web 服务和 API 接口,以便与 Vanna 实例进行交互。

主要作用
缓存管理:定义一个缓存接口 Cache 及其内存实现 MemoryCache,用于在应用中缓存数据。
认证管理:提供认证装饰器 requires_auth,确保只有经过认证的用户才能访问某些路由。
路由定义:定义了多个 API 路由,用于处理不同的 HTTP 请求,例如生成 SQL、获取配置、管理训练数据等。
WebSocket 支持:使用 flask_sock 添加 WebSocket 支持,用于实时日志传输。
import json  # 导入处理JSON数据的库
import logging  # 导入日志记录库
import sys  # 导入系统模块,用于与Python解释器交互
import uuid  # 导入UUID模块,用于生成唯一标识符
from abc import ABC, abstractmethod  # 导入抽象基类和抽象方法
from functools import wraps  # 导入装饰器工具

import flask  # 导入Flask库,用于创建Web应用
import requests  # 导入请求库,用于发送HTTP请求
from flask import Flask, Response, jsonify, request  # 从Flask库中导入相关类和方法
from flask_sock import Sock  # 导入用于处理WebSocket的库

from .assets import css_content, html_content, js_content  # 导入静态资源内容
from .auth import AuthInterface, NoAuth  # 导入认证接口和无认证类

class Cache(ABC):
    """
    定义一个用于在Flask应用中存储数据的缓存接口。
    """

    @abstractmethod
    def generate_id(self, *args, **kwargs):
        """
        生成一个唯一的缓存ID。
        """
        pass

    @abstractmethod
    def get(self, id, field):
        """
        从缓存中获取一个值。
        """
        pass

    @abstractmethod
    def get_all(self, field_list) -> list:
        """
        从缓存中获取所有值。
        """
        pass

    @abstractmethod
    def set(self, id, field, value):
        """
        在缓存中设置一个值。
        """
        pass

    @abstractmethod
    def delete(self, id):
        """
        从缓存中删除一个值。
        """
        pass

class MemoryCache(Cache):
    def __init__(self):
        self.cache = {}  # 初始化一个空字典作为缓存

    def generate_id(self, *args, **kwargs):
        return str(uuid.uuid4())  # 生成一个UUID作为ID

    def set(self, id, field, value):
        if id not in self.cache:
            self.cache[id] = {}  # 如果ID不存在,则在缓存中创建一个新条目

        self.cache[id][field] = value  # 设置字段值

    def get(self, id, field):
        if id not in self.cache:
            return None  # 如果ID不存在,返回None

        if field not in self.cache[id]:
            return None  # 如果字段不存在,返回None

        return self.cache[id][field]  # 返回字段值

    def get_all(self, field_list) -> list:
        return [
            {"id": id, **{field: self.get(id=id, field=field) for field in field_list}}
            for id in self.cache
        ]  # 返回所有缓存条目

    def delete(self, id):
        if id in self.cache:
            del self.cache[id]  # 删除缓存条目

class VannaFlaskApp:
    flask_app = None  # 初始化flask_app为None

    def requires_cache(self, required_fields, optional_fields=[]):
        def decorator(f):
            @wraps(f)
            def decorated(*args, **kwargs):
                id = request.args.get("id")  # 从请求参数中获取ID

                if id is None:
                    id = request.json.get("id")  # 如果ID为空,从JSON数据中获取ID
                    if id is None:
                        return jsonify({"type": "error", "error": "No id provided"})  # 返回错误响应

                for field in required_fields:
                    if self.cache.get(id=id, field=field) is None:
                        return jsonify({"type": "error", "error": f"No {field} found"})  # 如果必需字段不存在,返回错误响应

                field_values = {
                    field: self.cache.get(id=id, field=field) for field in required_fields
                }  # 获取必需字段值

                for field in optional_fields:
                    field_values[field] = self.cache.get(id=id

auth.py

from abc import ABC, abstractmethod  # 导入抽象基类和抽象方法

import flask  # 导入Flask库,用于创建Web应用

class AuthInterface(ABC):
    @abstractmethod
    def get_user(self, flask_request) -> any:
        pass  # 获取用户信息

    @abstractmethod
    def is_logged_in(self, user: any) -> bool:
        pass  # 检查用户是否已登录

    @abstractmethod
    def override_config_for_user(self, user: any, config: dict) -> dict:
        pass  # 根据用户覆盖配置

    @abstractmethod
    def login_form(self) -> str:
        pass  # 返回登录表单的HTML

    @abstractmethod
    def login_handler(self, flask_request) -> str:
        pass  # 处理登录请求

    @abstractmethod
    def callback_handler(self, flask_request) -> str:
        pass  # 处理认证回调

    @abstractmethod
    def logout_handler(self, flask_request) -> str:
        pass  # 处理注销请求

class NoAuth(AuthInterface):
    def get_user(self, flask_request) -> any:
        return {}  # 返回空的用户信息

    def is_logged_in(self, user: any) -> bool:
        return True  # 始终返回True,表示用户已登录

    def override_config_for_user(self, user: any, config: dict) -> dict:
        return config  # 不修改配置,直接返回

    def login_form(self) -> str:
        return ''  # 返回空字符串,表示不需要登录表单

    def login_handler(self, flask_request) -> str:
        return 'No login required'  # 返回提示信息,表示不需要登录

    def callback_handler(self, flask_request) -> str:
        return 'No login required'  # 返回提示信息,表示不需要认证回调

    def logout_handler(self, flask_request) -> str:
        return 'No login required'  # 返回提示信息,表示不需要注销

ollama

ollama.py

  • 与 Ollama 模型的通信Ollama 类实现了与 Ollama 模型的接口,通过 HTTP 请求与 Ollama 模型进行通信,发送提示并接收响应。
  • SQL 语句提取:提供从 LLM 响应中提取 SQL 语句的方法,便于后续处理和执行。
  • 日志记录:记录请求参数、提示内容和响应,便于调试和分析。

Ollama 类实现了与 Ollama 模型的基本接口,提供了初始化、消息处理、SQL 提取和提示提交等功能。该类用于与 Ollama 模型进行通信,发送提示并接收响应,同时记录相关日志以便调试和分析。

import json  # 导入 json 模块,用于处理 JSON 数据
import re  # 导入 re 模块,用于正则表达式操作

from httpx import Timeout  # 从 httpx 库导入 Timeout 类,用于设置请求超时

from ..base import VannaBase  # 从上一级目录的 base 模块导入 VannaBase 类
from ..exceptions import DependencyError  # 从上一级目录的 exceptions 模块导入 DependencyError 类


class Ollama(VannaBase):  # 定义一个名为 Ollama 的类,继承自 VannaBase
  def __init__(self, config=None):  # 定义类的构造函数,接受一个可选的配置参数 config
    try:
      ollama = __import__("ollama")  # 尝试导入 ollama 模块
    except ImportError:
      raise DependencyError(
        "You need to install required dependencies to execute this method, run command:"
        " \npip install ollama"  # 如果导入失败,抛出 DependencyError 异常,并提示用户安装所需依赖
      )

    if not config:
      raise ValueError("config must contain at least Ollama model")  # 如果没有提供 config,抛出 ValueError 异常
    if 'model' not in config.keys():
      raise ValueError("config must contain at least Ollama model")  # 如果 config 中不包含 'model' 键,抛出 ValueError 异常
    self.host = config.get("ollama_host", "http://localhost:11434")  # 获取 Ollama 主机地址,默认为 "http://localhost:11434"
    self.model = config["model"]  # 获取 Ollama 模型名称
    if ":" not in self.model:
      self.model += ":latest"  # 如果模型名称中不包含 ":",则添加 ":latest"

    self.ollama_client = ollama.Client(self.host, timeout=Timeout(240.0))  # 创建 Ollama 客户端实例,设置超时时间为 240 秒
    self.keep_alive = config.get('keep_alive', None)  # 获取 keep_alive 配置,默认为 None
    self.ollama_options = config.get('options', {})  # 获取 Ollama 选项配置,默认为空字典
    self.num_ctx = self.ollama_options.get('num_ctx', 2048)  # 获取 num_ctx 配置,默认为 2048
    self.__pull_model_if_ne(self.ollama_client, self.model)  # 调用私有方法 __pull_model_if_ne,检查并拉取模型

  @staticmethod
  def __pull_model_if_ne(ollama_client, model):  # 定义一个静态方法 __pull_model_if_ne,接受 ollama_client 和 model 参数
    model_response = ollama_client.list()  # 获取模型列表
    model_lists = [model_element['model'] for model_element in
                   model_response.get('models', [])]  # 提取模型名称列表
    if model not in model_lists:
      ollama_client.pull(model)  # 如果模型不在列表中,拉取模型

  def system_message(self, message: str) -> any:  # 定义一个方法 system_message,接受一个字符串参数 message
    return {"role": "system", "content": message}  # 返回一个包含角色和内容的字典

  def user_message(self, message: str) -> any:  # 定义一个方法 user_message,接受一个字符串参数 message
    return {"role": "user", "content": message}  # 返回一个包含角色和内容的字典

  def assistant_message(self, message: str) -> any:  # 定义一个方法 assistant_message,接受一个字符串参数 message
    return {"role": "assistant", "content": message}  # 返回一个包含角色和内容的字典

  def extract_sql(self, llm_response):  # 定义一个方法 extract_sql,接受一个字符串参数 llm_response
    """
    提取第一个 SQL 语句,忽略大小写,
    匹配到第一个分号、三个反引号或字符串末尾,
    并移除提取字符串中的三个反引号。

    参数:
    - llm_response (str): 要搜索的字符串。

    返回:
    - str: 提取的第一个 SQL 语句,移除三个反引号,如果没有匹配项,则返回空字符串。
    """
    # 移除 ollama 生成的额外字符
    llm_response = llm_response.replace("\\_", "_")
    llm_response = llm_response.replace("\\", "")

    # 正则表达式查找 ```sql' 并捕获直到 '```'
    sql = re.search(r"```sql\n((.|\n)*?)(?=;|\[|```)", llm_response, re.DOTALL)
    # 正则表达式查找 'select, with (忽略大小写) 并捕获直到 ';', [ (在 mistral 中发生) 或字符串末尾
    select_with = re.search(r'(select|(with.*?as \())(.*?)(?=;|\[|```)',
                            llm_response,
                            re.IGNORECASE | re.DOTALL)
    if sql:
      self.log(
        f"Output from LLM: {llm_response} \nExtracted SQL: {sql.group(1)}")
      return sql.group(1).replace("```", "")
    elif select_with:
      self.log(
        f"Output from LLM: {llm_response} \nExtracted SQL: {select_with.group(0)}")
      return select_with.group(0)
    else:
      return llm_response

  def submit_prompt(self, prompt, **kwargs) -> str:  # 定义一个方法 submit_prompt,接受一个参数 prompt 和其他可选参数
    self.log(
      f"Ollama parameters:\n"
      f"model={self.model},\n"
      f"options={self.ollama_options},\n"
      f"keep_alive={self.keep_alive}")  # 记录 Ollama 参数
    self.log(f"Prompt Content:\n{json.dumps(prompt)}")  # 记录提示内容
    response_dict = self.ollama_client.chat(model=self.model,
                                            messages=prompt,
                                            stream=False,
                                            options=self.ollama_options,
                                            keep_alive=self.keep_alive)  # 调用 Ollama 客户端的 chat 方法,获取响应

    self.log(f"Ollama Response:\n{str(response_dict)}")  # 记录 Ollama 响应

    return response_dict['message']['content']  # 返回响应中的消息内容
代码功能和作用
  1. 导入模块

    • jsonre:用于处理 JSON 数据和正则表达式操作。
    • Timeout:用于设置请求超时。
    • VannaBaseDependencyError:用于继承基类和处理依赖错误。
  2. 定义 Ollama

    • 继承自 VannaBase,实现了与 Ollama 模型的接口。
  3. 构造函数 __init__

    • 尝试导入 ollama 模块,如果失败则抛出 DependencyError
    • 验证 config 参数是否存在并包含必要的键。
    • 初始化 Ollama 客户端和模型参数。
    • 调用私有方法 __pull_model_if_ne 检查并拉取模型。
  4. 静态方法 __pull_model_if_ne

    • 检查指定模型是否存在,如果不存在则拉取模型。
  5. 消息方法

    • system_messageuser_messageassistant_message 方法分别返回带有角色和内容的字典,用于与 Ollama 模型的通信。
  6. extract_sql 方法

    • 从 LLM 响应中提取 SQL 语句,移除生成的额外字符,并使用正则表达式匹配和提取 SQL 语句。
  7. submit_prompt 方法

    • 记录请求参数和提示内容,调用 Ollama 客户端的 chat 方法获取响应,并返回响应中的消息内容。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值