实现的基本思路
- 获取用户输入的问题
- 通过提示词让大模型生成SQL代码
- 提取纯SQL代码
- 执行SQL查询
- 综合数据库信息、生成SQL信息、执行结果信息组装最终回答的prompt
- 基于prompt回答用户问题
以下是实现的代码:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough,RunnableLambda
from langchain_community.utilities import SQLDatabase
from langchain_community.chat_models import ChatTongyi
# 创建数据库
db = SQLDatabase.from_uri("sqlite:///ali_langchain.db")
# 获取数据表schema信息
def get_schema(_):
return db.get_table_info()
# 用数据库执行sql代码
def run_query(query):
return db.run(query)
# 提取大模型回答中的纯SQL代码
def get_res_sql(res):
return res.split("```sql")[1].split("```")[0]
template_sql = """
请通过编写SQL代码来回答用户提的问题,回答需要基于如下数据库信息:{info}
注意,仅需要通过SQL代码回答,不需要额外添加说明文字。
代码示例如下:
\n\n```sql\nSELECT COUNT(*) FROM alidata;\n```
需要回答的问题是:{question}
"""
# 创建提示词模板
prompt = PromptTemplate.from_template(template_sql)
# 创建大模型
llm = ChatTongyi()
# 创建一个自定义链,获取最终可执行SQL代码
chain_sql = ({"info":get_schema,"question":RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
| RunnableLambda(get_res_sql))
final_template = """
请参考如下信息,给出问题的自然语言回答:
1.数据库信息:{info}
2.用户提问:{question}
3.SQL代码:{query}
4.SQL代码执行结果:{sql_exec_result}
"""
final_prompt = PromptTemplate.from_template(final_template)
# 创建一个整合链,用来最终回复用户
final_chain = ({"info":get_schema,"question":RunnablePassthrough(),"query":chain_sql}
| RunnablePassthrough.assign(sql_exec_result=lambda x:run_query(x["query"]))
| final_prompt
| llm
| StrOutputParser()
)
final_response = final_chain.invoke("记录数最多的五个brand是哪些?")
print(final_response)