LLM:企业AI应用实例(3)-text2SQL

适用于LLM初学者。学习调用常规的大语言模型,并进行自然语言生成SQL,并取数。

1 前言

随着大模型应用的普及,一些企业开始尝试将各种场景与大模型进行融合。本文旨在利用大模型生成SQL。

2 Langchain方式

LangChain 是一个用于开发由大型语言模型 (LLM) 提供支持的应用程序的框架。简化了 LLM 应用程序生命周期的每个阶段。因此非常适合新手,并且有丰富的文档供新手上路。链接:
快速入门 |🦜️🔗 LangChain 语言链icon-default.png?t=O83Ahttps://python.langchain.com/v0.1/docs/use_cases/sql/quickstart/

这里我们使用SQLchain来快速开始。

2.1 SQLchain

从高层次来看,任何 SQL 链和代理的步骤都是:

  1. 将问题转换为 SQL 查询:模型将用户输入转换为 SQL 查询。
  2. 执行SQL查询:执行SQL查询。
  3. 回答问题:模型使用查询结果响应用户输入。

2.2 第一个Demo

2.2.1 获取所需的包并设置环境变量

安装langchain(可以参考langchain的文档)

在终端中执行

!pip install --upgrade --quiet  langchain langchain-community langchain-openai

这里以笔者为例,准备实例数据


# 准备实例数据
PS C:\Users\li_zh> cd "D:\learn_project\sqlchain"
PS D:\learn_project\sqlchain> sqlite3 Chinook.db
SQLite version 3.41.2 2023-03-22 11:56:21
Enter ".help" for usage hints.
sqlite> .read Chinook_Sqlite.sql
sqlite> SELECT * FROM Artist LIMIT 10;
1|AC/DC
2|Accept
3|Aerosmith
4|Alanis Morissette
5|Alice In Chains
6|Ant?nio Carlos Jobim
7|Apocalyptica
8|Audioslave
9|BackBeat
10|Billy Cobham
sqlite>

2.2.2 使用 SQLite 与 Chinook 数据库连接

按照上述方法,数据已经在我们的目录中,现在可以使用 SQLAlchemy 驱动的类与它进行交互

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

# "[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

2.2.3 将问题转换为 SQL

我们创建一个简单的链,它接受一个问题,将其转换为 SQL 查询,执行查询,并使用结果来回答原始问题。SQL 链或代理的第一步是获取用户输入并将其转换为 SQL 查询。LangChain 为此自带了一个内置链。

注意:

填写自己的的openai_api_baseopenai_api_key,国内很难直接使用gpt,所以找个中转即可。

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(openai_api_base="https://aihubmix.com/v1", 
                 openai_api_key=OPENAI_API_KEY,
                 model="gpt-3.5-turbo", 
                 temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
response
# 'SELECT COUNT("EmployeeId") AS "TotalEmployees" FROM "Employee"'

db.run(response)
# '[(8,)]'

如上,我们跑完了第一个demo

用户提问:How many employees are there?

模型返回结果:response = 'SELECT COUNT("EmployeeId") AS "TotalEmployees" FROM "Employee"'

执行sql后返回的结果:8

2.2.4 用langchain的语法,合起来执行

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})

和之前的内容无差比,只不过用langchain推荐的语法写 

2.2.5 添加大模型的回答

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | llm | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "How many employees are there"})

# 'There are a total of 8 employees.'

回答:# 'There are a total of 8 employees.'

3 实操补充

真正使用的时候,数据源是咱们自己的,因此跑完官方给的Demo后,必不可免的是替换sql库。

3.1 配置好数据库

数据库的配置对没使用过的新手来说可能有困难,笔者简述一下,有基础的朋友可跳过此章节。

3.1.1 MySQL安装

MySQL :: Begin Your Download

按照官网教程安装数据库

3.1.2 DBeaver安装

dbeaver是一款的数据库连接工具,免费,跨平台。

dbeaver安装和使用教程_dbeaver使用教程-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/tennysonsky/article/details/122397486

照教程把DBeaver安装好。官网如下

Download | DBeaver Community

然后用DBeaver与数据库建立连接(有困难的朋友多搜一下,教程都有)

最后,别忘了将数据放入库中

3.2 建立连接并测试

开始前别忘了安装

! pip install mysql-connector-python

建立连接并测试

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("mysql+mysqlconnector://tom:818200@localhost/world")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM city LIMIT 10;")



from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(openai_api_base="https://aihubmix.com/v1",
                 openai_api_key=OPENAI_API_KEY,
                 model="gpt-3.5-turbo",
                 temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "酒店id是9003051,2022年4月18日的客房收入是多少?"})
print(response)
db.run(response)

其余操作与2章相似。

4 few-shoot与RAG2SQL

4.1 概述

简单来说,text2SQL的实现如下图:

如何进一步提高正确率是我们关心的,因为实际使用中,SQL的输出并不是每次都是我们想要的(格式、语法、随机性的影响等)

4.2 实现

4.2.1 准备环境

这里我们使用的deepseek的模型

from openai import OpenAI
import os
os.environ['DASHSCOPE_API_KEY']='sk-806b105ccce348079xxxxxxxxxxxxx'
import dashscope
from dashscope import Generation
from http import HTTPStatus
from dashscope import TextEmbedding
import pandas as pd
import os
from time import sleep
import pickle
import json
import numpy as np
import requests
import timeit

4.2.2 系统提示词

def get_prompt1(data_structure):
    prompt1=f'''
    你是一个数据分析专家!请使用下面的数据结构信息,该数据结构信息用json格式呈现,请通过duckdb sql数据分析回答用户的问题。

    数据结构信息:
    {data_structure}
    
    '''
    return prompt1

数据结构信息:

# 数据结构信息
daily_data_structure=[{
    "table_name": "ads_powerboss_revenue_hotel_d",
    "数据结构整体概述": "这是每日酒店经营数据表,这是一张以日为单位的记录酒店的收入、收益相关指标的数据表。数据包含以下列:create_day, hotel_id, crevenuetotal, cbudgetrevenuetotal, crevenuetotalcompleterate, crevenueroom, crevenueroomnohourrent, crevenueroomhourrent, crevenuenoroom, crevenuemeetingroom, crevenuedinner, crevenuemembercard, crevenueother, croomcount, croomday, croomdaynohourrent, croomdayhourrent, revpar, occ, occnohourrent, adr, weeklabel, l1, l2, l3, l4, etl_ins_tm ",
    "每列数据结构的详细分析说明": [
        {
            "create_day": {
                "中文名称": "日期",
                "描述": "每条记录的发生日期,精确到每一天,日期格式为yyyy-MM-dd,即yyyy年,MM月,dd日",
                "数据类型": "DATE",
                "数据格式": "yyyy-MM-dd",
                "示例": "2023-01-01"
            }
        },
        {
            "hotel_id": {
                "中文名称": "酒店ID",
                "描述": "每家酒店的唯一标识符,通常由一串数字组成",
                "数据类型": "VARCHAR",
                "示例": "1000111"
            }
        },
        {
            "crevenuetotal":{
                "中文名称":"每日总营收",
                "描述":"酒店每天的总营业收入",
                "数据类型":"DECIMAL",
                "示例":"27146.9"
            }
        },
        {
            "cbudgetrevenuetotal": {
                "中文名称": "总营收预算",
                "描述": "总营收预算预算",
                "数据类型": "DECIMAL",
                "示例": "30000.0"
            }
        },
        {
            "etl_ins_tm": {
                "中文名称": "ETL插入时间",
                "描述": "数据被插入到数据仓库的时间",
                "数据类型": "VARCHAR",
                "示例": "2023-01-01 12:00:00"
                }
            }
        ]
    }
]

如上图,为了更好的准确率,数据结构需要被定义清楚,给大模型详细的信息。

4.2.3 整体提示词

def get_prompt200(query,sql_rag,h_id):
    prompt200=f'''
    约束:用户问的所有的问题都是关于酒店ID为{h_id}这家酒店的,因此如果问题中没有指明是哪家酒店或者问题中提到本店、这家店等词汇,请默认hotel_id为{h_id}。
    
    在回答问题时请参考如下模板:
    {sql_rag}
    
    请一步一步思考,给出回答,并务必确保你的回答内容格式如下:
    <api-call><summary>用数据分析的方式简要回答用户的问题</summary><sql>正确的mysql数据分析sql</sql></api-call>
    注意:你只能输出该格式,禁止输出其他任何内容或思考过程。
    
    用户问题:
    {query}

    '''
    return prompt200

4.2.4 连接数据库

from langchain_community.utilities import SQLDatabase
import ast
import pandas as pd
db = SQLDatabase.from_uri("mysql+mysqlconnector://tom:xxxxx314+@rm-uf6o98xxxxxxxxxxo.mysql.rds.aliyuncs.com/my_test")
print("Opened database successfully")
print(db.dialect)
print(db.get_usable_table_names())

db.get_table_info(table_names=["ads_powerboss_revenue_hotel_d"])  # 查看建表信息
result = db.run_no_throw("select * from ads_powerboss_revenue_hotel_d limit 5", 
                          include_columns=True)  # 执行sql

查看deepseek信息

import requests

url = "https://api.deepseek.com/v1/models"

payload={}
headers = {
  'Accept': 'application/json',
  'Authorization': 'Bearer sk-xxxxxxxxxxxxxxxxxxx'
}

response = requests.request("GET", url, headers=headers, data=payload)
print(response)
print(response.text)

4.2.5 llm_request

写好大模型的请求函数

def llm_request(content1,content2):
    # client = OpenAI(api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx", base_url="https://api.deepseek.com/v1")
    client = OpenAI(api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", base_url="https://aihubmix.com/v1")

    response = client.chat.completions.create(
        model="deepseek-coder",
        messages=[
            {"role": "system", "content": content1},
            {"role": "user", "content": content2},
        ],
        temperature=0.1
    )
    results=response.choices[0].message.content
    return results

4.2.6 添加few-shoot

这里是专门添加的few-shoot是为大模型准备的一些标准示例,让大模型能够理解并输出我们希望的sql语句。

注:如果few-shoot不是固定的,而是用过rag的方式调取,则变为了常说的RAG2SQL

start = timeit.default_timer()
h_id = '9ccc222'
query="昨日的时租房营收是多少?"
sql_rag = f'''
query1 = 昨日的总营收怎么样?
sql1:
SELECT 
  (SELECT crevenuetotal 
   FROM ads_powerboss_revenue_hotel_d 
   WHERE hotel_id = '0500911' 
   AND create_day = 
     (SELECT DATE_FORMAT(prev1daydate, '%Y-%m-%d')
      FROM com_day_d
      WHERE date_name = CURDATE()) 
  ) AS revenue_previous_day,

  (SELECT crevenuetotal 
   FROM ads_powerboss_revenue_hotel_d 
   WHERE hotel_id = '0500911' 
   AND create_day = 
     (SELECT DATE_FORMAT(prev1weekdate, '%Y-%m-%d')
      FROM com_day_d
      WHERE date_name = DATE_SUB(CURDATE(), INTERVAL 1 DAY)) 
  ) AS revenue_previous_week,

  (SELECT crevenuetotal 
   FROM ads_powerboss_revenue_hotel_d 
   WHERE hotel_id = '0500911' 
   AND create_day = 
     (SELECT DATE_FORMAT(prev1yearweekdate, '%Y-%m-%d')
      FROM com_day_d
      WHERE date_name = DATE_SUB(CURDATE(), INTERVAL 1 DAY)) 
  ) AS revenue_previous_year;
  
query2 = 昨日的总营收是多少?
sql2:
SELECT crevenuetotal
FROM ads_powerboss_revenue_hotel_d
WHERE hotel_id = '0500911' 
AND create_day = DATE_SUB(CURDATE(), INTERVAL 1 DAY)
  
query3 = 本周的总营收是多少?
sql3:
SELECT crevenuetotal
FROM ads_powerboss_revenue_hotel_w
WHERE hotel_id = '0500911'
AND week_key2 in (select week_key2 from com_day_d where date_name = CURDATE())

query4 =上周的总营收是多少?
sql4:
SELECT crevenuetotal
FROM ads_powerboss_revenue_hotel_w
WHERE hotel_id = '0500911'
AND week_key2 IN (
    SELECT week_key2 
    FROM com_day_d 
    WHERE date_name = DATE_SUB(CURDATE(), INTERVAL 7 DAY)
)

query5 = 上周三的总营收是多少?
sql5:
SELECT crevenuetotal
FROM ads_powerboss_revenue_hotel_d
WHERE hotel_id = '0500911'
AND create_day = 
(select date_name from com_day_d
where week_key1 = (select week_key1 from com_day_d where date_name = date_sub(CURDATE(),interval 1 week))
and weekday_key = 3);

'''

sys_prompt = get_prompt1(daily_data_structure)
us_prompt = get_prompt200(query, sql_rag, h_id)
answer=llm_request(sys_prompt,us_prompt)
print(answer)
end=timeit.default_timer()
print('Running time: %s Seconds'%(end-start))

4.2.7 批量测试

最后可以批量的进行尝试。

读取data.xlsx(一批问题),输出result.xlsx(模型的结果)

import time
df = pd.read_excel('data.xlsx')
sys_prompt = get_prompt1(daily_data_structure)
for i in range(len(df['query'])):
    start_time = time.time()  # 开始计时
    
    us_prompt = get_prompt200(df['query'][i], sql_rag, h_id)
    answer=llm_request(sys_prompt,us_prompt)
    print(answer)
    
    end_time = time.time()  # 结束计时
    duration = end_time - start_time  # 计算持续时间
    
    df.at[i, 'time'] = duration  # 使用.at来安全地修改DataFrame
    df.at[i, 'gpt4_answer'] = answer


df.to_excel('result.xlsx', index=False)  # index=False表示不保存行索引

大功告成,建议朋友们在自己的数据上多尝试。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

lzc1009840152

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值