Vanna源码研究之promp提示词以及train核心函数

前言

Vanna本质就是一个Python包,其核心在于RAG(检索增强)方法,通过检索增强的方式来构建Prompt,进而显著提高SQL查询生成的准确率。

对于vanna的一些概念性的东西大家都可以搜到,我没有在网上搜到关于vanna源码研究的一些比较好的博客,而自己又特别需要这一块的知识和指导,所以自己动手研究了一些成果,大家有需要的可以随时关注。话不多说直接上源码,其最重要的函数都写在了vanna\src\base\base.py

1. prompt提示词研究

# 使用 LLM 生成回答问题的 SQL 查询
    def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> str:
        if self.config is not None:
            initial_prompt = self.config.get("initial_prompt", None)  # 获取初始提示
        else:
            initial_prompt = None
         # 获取相似问题的 SQL
        question_sql_list = self.get_similar_question_sql(question, **kwargs)  
        ddl_list = self.get_related_ddl(question, **kwargs)  # 获取相关的 DDL
        doc_list = self.get_related_documentation(question, **kwargs)  # 获取相关的文档

        # 设计提示词
        prompt = self.get_sql_prompt(
        # 初始提示,通常是一些通用的指导信息,告诉 LLM 需要执行什么任务。
    		initial_prompt=initial_prompt,  
    		question=question,  # 用户提出的问题,模型需要根据这个问题生成 SQL 查询
# 类似问题和相应 SQL 查询的列表,用于给 LLM 提供参考示例,帮助它更好地理解如何生成 SQL 查询。 
    		question_sql_list=question_sql_list,  
    		ddl_list=ddl_list,  # 数据定义语言的语句列表,定义了数据库中的表结构
    		doc_list=doc_list,  # 文档列表,提供额外的上下文或背景信息,帮助 LLM 更好地理解问题的背景。
    		**kwargs,  # 其他可能的额外参数
)
        
        self.log(title="SQL Prompt", message=prompt)  # 记录生成的 SQL 提示
        llm_response = self.submit_prompt(prompt, **kwargs)  # 提交提示并获取 LLM 响应
        self.log(title="LLM Response", message=llm_response)  # 记录 LLM 响应

        if 'intermediate_sql' in llm_response:
            if not allow_llm_to_see_data:
                return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."  # 如果不允许 LLM 查看数据,则返回相应提示

            if allow_llm_to_see_data:
                intermediate_sql = self.extract_sql(llm_response)  # 提取中间 SQL

                try:
                    self.log(title="Running Intermediate SQL", message=intermediate_sql)  # 记录中间 SQL
                    df = self.run_sql(intermediate_sql)  # 运行中间 SQL 并获取结果
					
                    # 生成最终的提示词模板
                    prompt = self.get_sql_prompt(
                        initial_prompt=initial_prompt,
                        question=question,
                        question_sql_list=question_sql_list,
                        ddl_list=ddl_list,
                        doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()], # 将中间 SQL 查询的结果添加到 doc_list中,提高生成的SQL准确性
                        **kwargs,
                    )
                    self.log(title="Final SQL Prompt", message=prompt)  # 记录最终的 SQL 提示
                    
                    llm_response = self.submit_prompt(prompt, **kwargs)  # 提交提示并获取 LLM 响应
                    self.log(title="LLM Response", message=llm_response)  # 记录 LLM 响应
                except Exception as e:
                    return f"Error running intermediate SQL: {e}"  # 记录运行中间 SQL 时的错误

        return self.extract_sql(llm_response)  # 返回提取的 SQL 查询

这里我们可以看到对于prompt提示词的组装核心代码

 prompt = self.get_sql_prompt(
        # 初始提示,通常是一些通用的指导信息,告诉 LLM 需要执行什么任务。
    		initial_prompt=initial_prompt,  
    		question=question,  # 用户提出的问题,模型需要根据这个问题生成 SQL 查询
# 类似问题和相应 SQL 查询的列表,用于给 LLM 提供参考示例,帮助它更好地理解如何生成 SQL 查询。 
    		question_sql_list=question_sql_list,  
    		ddl_list=ddl_list,  # 数据定义语言的语句列表,定义了数据库中的表结构
    		doc_list=doc_list,  # 文档列表,提供额外的上下文或背景信息,帮助 LLM 更好地理解问题的背景。
    		**kwargs,  # 其他可能的额外参数
)

然后我们按住crtl+鼠标左键点击get_sql_prompt进入到这个函数中,对于这个prompt我们基本就一目了然了

 def get_sql_prompt(
        self,
        initial_prompt : str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """
        Example:
        ```python
        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )

        ```

        This method is used to generate a prompt for the LLM to generate SQL.

        Args:
            question (str): The question to generate SQL for.
            question_sql_list (list): A list of questions and their corresponding SQL statements.
            ddl_list (list): A list of DDL statements.
            doc_list (list): A list of documentation.

        Returns:
            any: The prompt for the LLM to generate SQL.
        """

        if initial_prompt is None:
            initial_prompt = f"You are a {self.dialect} expert. " + \
            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            doc_list.append(self.static_documentation)

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )

        # 这里给出了对于模型回复的规定
        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please explain why it can't be generated. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
1.如果提供的上下文足够,请生成一个有效的SQL查询,不对问题进行任何解释。
2.如果提供的上下文几乎足够,但需要了解特定列中的特定字符串,请生成一个中间SQL查询来查找该列中的不同字符串。在查询前添加一条注释,说明intermediate_SQL
3.如果提供的上下文不充分,请解释为什么无法生成
4.请使用最相关的表。
5.如果问题之前已经被问过并回答过,请完全按照之前给出的答案重复。
        )

        message_log = [self.system_message(initial_prompt)]

        for example in question_sql_list:
            if example is None:
                print("example is None")
            else:
                if example is not None and "question" in example and "sql" in example:
                    message_log.append(self.user_message(example["question"]))
                    message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))

        return message_log

在这个函数中,它首先给出了一个样例,可以帮助我们更快的了解这个函数的用法,相信对于这个使用案例大家肯定一目了然了

        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )

在vanna提示词中会把DDL,[{question:sql}]一系列相似的问题sql对还有doc解释文档,组合起来,形成第一轮的提示词,这里为什么说是第一轮呢,可以这一步

initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please explain why it can't be generated. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"

我们对其进行翻译,可以得知,他对于模型的回复要求如下,这也会放到提示词中
1.如果提供的上下文足够,请生成一个有效的SQL查询,不对问题进行任何解释。
2.如果提供的上下文几乎足够,但需要了解特定列中的特定字符串,请生成一个中间SQL查询来查找该列中的不同字符串。在查询前添加一条注释,说明intermediate_SQL
3.如果提供的上下文不充分,请解释为什么无法生成
4.请使用最相关的表。
5.如果问题之前已经被问过并回答过,请完全按照之前给出的答案重复。

这里我们可以看到在第二个要求当中,如果有但需要了解特定列中的特定字符串,则会生成中间sql,这时我们在回到最开始的函数generate_sql

if 'intermediate_sql' in llm_response:
            if not allow_llm_to_see_data:
                return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."  # 如果不允许 LLM 查看数据,则返回相应提示

            if allow_llm_to_see_data:
                intermediate_sql = self.extract_sql(llm_response)  # 提取中间 SQL

                try:
                    self.log(title="Running Intermediate SQL", message=intermediate_sql)  # 记录中间 SQL
                    df = self.run_sql(intermediate_sql)  # 运行中间 SQL 并获取结果
					
                    # 生成最终的提示词模板
                    prompt = self.get_sql_prompt(
                        initial_prompt=initial_prompt,
                        question=question,
                        question_sql_list=question_sql_list,
                        ddl_list=ddl_list,
                        doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()], # 将中间 SQL 查询的结果添加到 doc_list中,提高生成的SQL准确性
                        **kwargs,
                    )
在这次生成的提示词模板中,他会将中间sql的执行结果以dataframe的形式添加到doc_list中,
并最终填充进prompt提示词中

以上就是对于vanna提示词设计的研究。

2. train()函数的作用

train函数见名知意,可以知道他是用来训练模型的。借助数据库的DDL语句、元数据(数据库内关于自身数据的描述信息)、相关文档说明、参考样例SQL等训练一个RAG的“模型”(embedding+向量库);
并在收到用户自然语言描述的问题时,从RAG模型中通过语义检索出相关的内容,进而组装进入Prompt,然后交给LLM生成SQL
那么这个训练实际是做了什么工作呢?
我们可以看他的源码工作流

# 训练 Vanna.AI
	def train(
	        self,
	        question: str = None,  # 用户问题
	        sql: str = None,  # SQL 查询语句
	        ddl: str = None,  # 数据定义语言语句
	        documentation: str = None,  # 文档内容
	        plan: TrainingPlan = None,  # 训练计划
	    ) -> str:
	        # 如果提供了问题但没有提供 SQL 查询,抛出验证错误
	        if question and not sql:
	            raise ValidationError("请同时提供 SQL 查询语句")
	
	        # 如果提供了文档内容,打印提示信息并添加文档
	        if documentation:
	            print("添加文档....")
	            return self.add_documentation(documentation)
	
	        # 如果提供了 SQL 查询语句,处理问题和 SQL
	        if sql:
	            if question is None:
	                question = self.generate_question(sql)  # 生成与 SQL 相关的问题
	                print("使用 SQL 生成的问题:", question, "\n添加 SQL...")
	            return self.add_question_sql(question=question, sql=sql)
	
	        # 如果提供了 DDL 语句,打印提示信息并添加 DDL
	        if ddl:
	            print("添加 DDL:", ddl)
	            return self.add_ddl(ddl)
	
	        # 如果提供了训练计划,依次处理每个计划项
	        if plan:
	            for item in plan._plan:
	                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
	                    self.add_ddl(item.item_value)  # 添加 DDL 项
	                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
	                    self.add_documentation(item.item_value)  # 添加文档项
	                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
	                    self.add_question_sql(question=item.item_name, sql=item.item_value)  # 添加 SQL 项

如果我们ctrl+鼠标左键,可以看到一系列的抽象方法,

  1. add_documentation
  2. add_question_sql
  3. add_ddl

此时你可以在项目里ctrl+shift+f进行全局搜索,看看在他的实现函数在哪里。
我们知道vanna支持多种向量数据库这里我们以Chrom向量数据库为例

我们可以在vanna\src\vanna\chromadb\chromadb_vector.py里找到对应抽象函数的实现
这里我们以add_documentation为例讲一下,其他两个类似

    def add_documentation(self, documentation: str, **kwargs) -> str:
    #生成一个确定性的UUID,将其与字符串"-doc"拼接在一起,生成最终的文档ID
        id = deterministic_uuid(documentation) + "-doc"
   #这行代码调用类的documentation_collection属性中的add方法,准备将文档添加到文档集合中。
        self.documentation_collection.add(
   #将文档内容(即documentation字符串)传递给documents参数
            documents=documentation,
     #将生成的文档嵌入(即文档的向量表示)传递给embeddings参数
            embeddings=self.generate_embedding(documentation),
            ids=id,
        )
        return id
 总结来说,这个函数的功能是将一个文档添加到一个文档集合中,并返回该文档的唯一ID

到这你可能还不太清楚train到底做了什么,可以回到源头,其核心在于RAG(检索增强)方法,通过检索增强的方式来构建prompt
其实train方法就是为了将DDL语句、元数据(描述数据库自身数据的信息)、相关文档说明、参考样例SQL等存入向量数据库。这里我们主要是讲了存入Chroma向量数据库

这里捎带的加一些RAG的知识,当我们存入向量数据库后,我们可以在回到生成提示词generate_sql()函数中

 
        question_sql_list = self.get_similar_question_sql(question, **kwargs)  
        ddl_list = self.get_related_ddl(question, **kwargs)  # 获取相关的 DDL
        doc_list = self.get_related_documentation(question, **kwargs)  # 获取相关的文档

vanna\src\vanna\chromadb\chromadb_vector.py里我们看一下它的实现函数

   def get_similar_question_sql(self, question: str, **kwargs) -> list:  # 获取类似问题的 SQL
        return ChromaDB_VectorStore._extract_documents(
            self.sql_collection.query(
                query_texts=[question],  # 查询文本
                n_results=self.n_results_sql,  # 查询结果数量
            )
        )

这里我们捋一下这个RAG实现过程

  1. 当我们需要使用vanna时,我们需要先试用train()进行训练,提供上下文,例如
vn.train(sql="SELECT COUNT(*) AS electric_shock_count FROM riskcontroller WHERE harm = '触电';")
  1. 训练完之后会存入向量数据库
  2. ask(‘潜在危险描述为触电的的条数有多少’)函数的时候
  3. 开始构建提示词,从向量数据库中进行检索,找到一系列相似的DDL、DOC、question-sql,来充当上下文
  4. 构建好的这个提示词会交给大模型
  5. 最后生成sql
  6. 如果需要后续还有对于执行sql之后使用pandas生成的图表的研究
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值