有关Text-to-SQL方法,可以查阅我的另一篇文章《Text-to-SQL方法研究》。
直接与数据库对话-text2sql
Text2SQL就是把自然语言文本转换为SQL语句。这段时间公司有这方面的需求,我调研了一下市面上的Text2SQL方案,比如阿里的Chat2DB、基于MIT协议开源的Vanna。试验之后,最终还是决定自研:沿用Vanna的思想,采用RAG+大模型的方式。
使用开源的Vanna实现Text2SQL比较方便,Vanna可以直接连接数据库;但是当用户有权限访问多个数据库时就比较麻烦了,而且Vanna向量化存储之后,新的question做相似度对比时不区分数据库。因此我自己实现了一版Text2SQL,仍然采用Vanna的思想,提前训练DDL、SQL-Question对和数据库Document。
这里简单做一下记录,以供后续学习使用。
基本思路
1、获取数据库DDL语句、SQL-Question对、Document信息
2、基于用户提问question和数据库Document锁定要分析的数据库
3、模型训练:借助数据库的DDL语句、元数据(描述数据库自身数据的信息)、相关文档说明、参考样例SQL等,训练一个RAG“模型”。
这一模型结合了embedding技术和向量数据库,使得数据库的结构和内容能够被高效地索引和检索。
4、语义检索: 当用户输入自然语言描述的问题时,①从向量库中检索,迅速找出与问题相关的内容;②用BM25算法进行文本召回,找到与问题最相关的内容;③分别使用RRF算法和Re-ranking重排序算法对召回结果进行融合与重排,锁定最相关的内容
语义匹配:使用算法(如BERT等)来理解查询和文档的语义相似性
文本召回匹配:BM25算法文本召回,找到与问题最相关的内容
rerank结果重排序:对召回的结果按与问题的相关性重新排序。
5、Prompt构建: 检索到的相关信息会被组装进Prompt中,形成一个结构化的查询描述。这一Prompt随后会被传递给LLM(大型语言模型)用于生成准确的SQL查询。
实现逻辑图
实现架构图:
具体实现方式如下所示:
1.数据库的选择
class DataBaseSearch(object):
    """Pick the database(s) most relevant to a user question.

    Every database record fetched from the backend contributes one vector to a
    FAISS L2 index. The vector is built from the record's SQL/question
    examples ("exp") and its description ("Intro"). A question is embedded and
    matched against these vectors to find candidate dataSetIDs.
    """

    def __init__(self, _model):
        self.name = 'DataBaseSearch'
        self.model = _model  # object exposing .encode(texts, normalize_embeddings=...)
        # Prefix added to the *query* only; stored passages are encoded without it.
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        # Must equal the embedding model's output dimension.
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.textdata = []  # list of (dataSetID, ddl, exp, Intro) tuples
        self.subdata = {}
        self.i2key = {}  # FAISS row number -> dataSetID
        self.id2ddls = {}  # dataSetID -> raw "ddl" payload
        self.id2sqlques = {}  # dataSetID -> raw "exp" payload
        self.id2docs = {}  # dataSetID -> raw "Intro" payload
        self.strtexts = {}  # dataSetID -> the text that was embedded
        self.load_textdata()  # fetch raw records from the backend
        self.load_textdata_vec()  # embed them into the FAISS index

    def load_textdata(self):
        """Fetch all database records into self.textdata.

        Best-effort: any request/parse error is printed and leaves
        self.textdata empty (or partially filled).
        """
        try:
            # NOTE(review): TLS certificate verification is disabled here.
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            for textdata in jsonobj["data"]:  # one entry per database
                cid = textdata["dataSetID"]
                cddls = textdata["ddl"]
                csql_ques = textdata["exp"]
                cdocuments = textdata["Intro"]
                self.textdata.append((cid, cddls, csql_ques, cdocuments))
        except Exception as e:
            print(e)

    def load_textdata_vec(self):
        """Embed each record's examples + description and index the vector."""
        for num0, (_id, _ddls, _sql_ques, _documents) in enumerate(self.textdata):
            # DDL is intentionally left out of the text used to pick a database.
            _strtexts = str(_sql_ques) + str(_documents)
            text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)
            self.index.add(text_embeddings)
            self.i2key[num0] = _id
            self.strtexts[_id] = _strtexts
            self.id2ddls[_id] = _ddls
            self.id2sqlques[_id] = _sql_ques
            self.id2docs[_id] = _documents

    def calculate_score(self, score, question, kws):
        """Placeholder for a custom re-scoring hook; not implemented (returns None)."""

    def find_vec_database(self, question, k, theata):
        """Return candidate databases for *question*.

        :param question: natural-language user question.
        :param k: number of nearest neighbours to request from the index.
        :param theata: cut-off; only hits whose score is < theata are kept.
        :return: list of {"score": int, "dataSetID": ...} dicts in the order
            returned by the index. score is the index distance * 1000,
            truncated to int (lower means closer).
        """
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for sim_i, sim_v in zip(I[0], D[0]):
            # FAISS pads with label -1 (and a huge distance) when the index
            # holds fewer than k vectors; scaling that distance overflows to
            # inf and int() would raise, so padding slots are skipped.
            if sim_i < 0:
                continue
            uuid = self.i2key.get(sim_i, "none")
            score = int(sim_v * 1000)
            if score < theata:
                result.append({"score": score, "dataSetID": uuid})
        return result
if __name__ == '__main__':
    # Demo: embed every database's examples/description, then pick the one
    # closest to a sample question (top-1, cut-off 2000).
    embedder = SentenceTransformer("E:\\module\\bge-large-zh-v1.5")
    db_search = DataBaseSearch(embedder)
    print(db_search.find_vec_database("查询济南市第三幼儿园所有小班班级?", 1, 2000))
2.sql_ques:sql问题训练
class SqlQuesSearch(object):
    """Retrieve example (question, SQL) pairs similar to a user question.

    Each example question is embedded into a FAISS L2 index. A lookup returns
    only the examples belonging to the requested database (dataSetID).
    """

    def __init__(self, _model):
        self.name = "SqlQuesSearch"
        self.model = _model  # object exposing .encode(texts, normalize_embeddings=...)
        # Prefix added to the *query* only; stored questions are encoded without it.
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024  # must equal the embedding model's output dimension
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.sqlquedata = []  # list of (dataSetID, [example, ...]) tuples
        self.i2dbid = {}  # FAISS row number -> dataSetID
        self.i2sqlid = {}  # FAISS row number -> sql_id
        self.id2sqlque = {}
        self.id2que = {}  # sql_id -> example question
        self.id2sql = {}  # sql_id -> example SQL
        self.dbid2sqlques = {}
        self.load_textdata()  # fetch raw examples from the backend
        self.load_textdata_vec()  # embed them into the FAISS index

    def load_textdata(self):
        """Fetch every database's examples ("exp") into self.sqlquedata.

        Best-effort: any request/parse error is printed, not raised.
        """
        try:
            # NOTE(review): TLS certificate verification is disabled here.
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            for datadata in jsonobj["data"]:  # one entry per database
                dbid = datadata["dataSetID"]
                sql_ques = datadata["exp"]
                self.sqlquedata.append((dbid, sql_ques))
        except Exception as e:
            print(e)

    def load_textdata_vec(self):
        """Embed every example question and add it to the FAISS index."""
        num0 = 0
        for db_id, sql_ques in self.sqlquedata:
            for sql_que in sql_ques:
                sql_id = sql_que["sql_id"]
                example_question = sql_que["question"]
                sql = sql_que["sql"]
                # Only the question is embedded; the SQL is looked up via sql_id.
                que_embeddings = self.model.encode([example_question], normalize_embeddings=True)
                self.index.add(que_embeddings)
                self.i2dbid[num0] = db_id
                self.i2sqlid[num0] = sql_id
                # NOTE(review): sql_id is assumed unique across all databases;
                # a repeated id overwrites the earlier question/sql here.
                self.id2que[sql_id] = example_question
                self.id2sql[sql_id] = sql
                num0 += 1
        print("init sql-que vec", num0)

    def calculate_score(self, sim_v, question, sql_ques):
        """Placeholder for a custom re-scoring hook; not implemented (returns None)."""

    def find_vec_sqlque(self, question, k, theta, dataSetID, number=None):
        """Return example (question, sql) pairs of *dataSetID* similar to *question*.

        :param question: natural-language user question.
        :param k: number of nearest neighbours requested from the index. The
            database filter is applied afterwards, so hits from other
            databases consume slots of k.
        :param theta: cut-off; only hits whose score is < theta are kept.
        :param dataSetID: only examples of this database are returned.
        :param number: maximum number of examples to return; None means no cap.
        :return: list of {"score", "question", "sql"} dicts in index order.
        """
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for sim_i, sim_v in zip(I[0], D[0]):
            if number is not None and len(result) >= number:
                break
            # Skip FAISS padding slots (label -1) and other databases' rows.
            if sim_i < 0 or self.i2dbid.get(sim_i, "none") != dataSetID:
                continue
            score = int(sim_v * 1000)
            if score < theta:
                sqlid = self.i2sqlid.get(sim_i, "none")
                result.append({
                    "score": score,
                    "question": self.id2que.get(sqlid, "none"),
                    "sql": self.id2sql.get(sqlid, "none"),
                })
        return result
if __name__ == '__main__':
    # Demo: retrieve up to 3 example question/SQL pairs of database 111.
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = SqlQuesSearch(model)
    # Arguments: question, k (neighbours searched), theta (score cut-off),
    # dataSetID, number (max examples returned).
    result = vs.find_vec_sqlque("查询7月18日所有的儿童观察记录?", 3, 2000, dataSetID=111, number=3)
    print(result)
3.数据库DDL训练
class DdlQuesSearch(object):
    """Retrieve the table DDL statements most relevant to a user question.

    Each DDL statement is embedded into a FAISS L2 index. A lookup returns
    only the DDLs belonging to the requested database (dataSetID).
    """

    def __init__(self, _model):
        self.name = "DdlQuesSearch"
        self.model = _model  # object exposing .encode(texts, normalize_embeddings=...)
        # Prefix added to the *query* only; stored DDLs are encoded without it.
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024  # must equal the embedding model's output dimension
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.ddldata = []  # list of (dataSetID, [ddl_item, ...]) tuples
        self.sqlques = {}
        self.i2dbid = {}  # FAISS row number -> dataSetID
        self.i2ddlid = {}  # FAISS row number -> ddl_id
        self.dbid2ddls = {}  # dataSetID -> {ddl_id: ddl text} of that database only
        self.id2ddl = {}  # ddl_id -> ddl text (all databases)
        self.ddlid2dbid = {}  # ddl_id -> dataSetID
        self.load_ddldata()  # fetch raw DDLs from the backend
        self.load_ddl_vec()  # embed them into the FAISS index

    def load_ddldata(self):
        """Fetch every database's DDL list into self.ddldata.

        Best-effort: any request/parse error is printed, not raised.
        """
        try:
            # NOTE(review): TLS certificate verification is disabled here.
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            # The per-database records live under the "data" key.
            databases = jsonobj["data"]
            for database in databases:
                db_id = database["dataSetID"]
                ddls = database["ddl"]
                self.ddldata.append((db_id, ddls))
        except Exception as e:
            print(e)

    def load_ddl_vec(self):
        """Embed every DDL statement and record its database/ddl ids."""
        num0 = 0
        for db_id, ddls in self.ddldata:
            # Each database gets its own mapping rather than the shared id2ddl.
            db_ddls = self.dbid2ddls.setdefault(db_id, {})
            for ddl_item in ddls:
                ddl_id = ddl_item["ddl_id"]
                ddl_text = ddl_item['ddl']
                ddl_embeddings = self.model.encode([ddl_text], normalize_embeddings=True)
                self.index.add(ddl_embeddings)
                self.i2dbid[num0] = db_id
                self.i2ddlid[num0] = ddl_id
                # NOTE(review): ddl_id is assumed unique across all databases.
                self.id2ddl[ddl_id] = ddl_text
                self.ddlid2dbid[ddl_id] = db_id
                db_ddls[ddl_id] = ddl_text
                num0 += 1
        print("init ddl vec", num0)

    def find_vec_ddl(self, question, k, theata, dataSetID, number=None):  # dataSetID:数据库id
        """Return DDLs of *dataSetID* similar to *question*.

        :param question: natural-language user question.
        :param k: number of nearest neighbours requested from the index. The
            database filter is applied afterwards, so hits from other
            databases consume slots of k.
        :param theata: cut-off; only hits whose score is < theata are kept.
        :param dataSetID: database id; only its DDLs are returned.
        :param number: maximum number of DDLs to return; None means no cap.
        :return: list of {"score": int, "ddl": str} dicts in index order.
        """
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for sim_i, sim_v in zip(I[0], D[0]):
            if number is not None and len(result) >= number:
                break
            # Skip FAISS padding slots (label -1) and other databases' rows.
            if sim_i < 0 or self.i2dbid.get(sim_i, "none") != dataSetID:
                continue
            score = int(sim_v * 1000)
            if score < theata:
                ddlid = self.i2ddlid.get(sim_i, "none")
                result.append({"score": score, "ddl": self.id2ddl.get(ddlid, "none")})
        return result
if __name__ == '__main__':
    # Demo: retrieve up to 2 DDLs of database 111 for a table description.
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = DdlQuesSearch(model)
    # Arguments: question, k (neighbours searched), theata (score cut-off),
    # dataSetID, number (max DDLs returned).
    ss = vs.find_vec_ddl("定时任务执行记录表", 2, 2000, 111, 2)
    print(ss)
4.数据库document训练
class DocQuesSearch(object):
    """Look up the free-text description ("Intro") of a database by its id.

    Unlike the other retrievers this is an exact id lookup; no embeddings
    are involved.
    """

    def __init__(self):
        self.name = "TestDataSearch"
        self.docdata = []  # list of (dataSetID, Intro) pairs
        self.load_doc_data()

    def load_doc_data(self):
        """Fetch every database's description into self.docdata.

        Best-effort: any request/parse error is printed, not raised.
        """
        try:
            # NOTE(review): TLS certificate verification is disabled here.
            resp = requests.post(url="xxx", verify=False)
            print(resp.text)
            payload = json.loads(resp.text)
            for entry in payload["data"]:
                self.docdata.append((entry["dataSetID"], entry["Intro"]))
        except Exception as err:
            print(err)

    def find_similar_doc(self, dataSetID):
        """Return every stored description whose id equals *dataSetID*."""
        return [doc for dbid, doc in self.docdata if dbid == dataSetID]
if __name__ == '__main__':
    # Demo: print the stored description(s) of database 222.
    print(DocQuesSearch().find_similar_doc(222))
5.生成SQL语句,这里使用的是qwen-max模型(模型名和API Key分别通过环境变量model和api_key配置)
import re
import random
import os, json
import dashscope
from dashscope.api_entities.dashscope_response import Message
from ddl_engine import DdlQuesSearch
from dashscope import Generation
from sqlques_engine import SqlQuesSearch
from sentence_transformers import SentenceTransformer
class Genarate(object):
    """Build a chat prompt from retrieved context and ask the LLM for SQL.

    The API key and model name are read from the ``api_key`` and ``model``
    environment variables.
    """

    def __init__(self):
        self.api_key = os.environ.get('api_key')
        self.model_name = os.environ.get('model')

    def system_message(self, message):
        return {'role': 'system', 'content': message}

    def user_message(self, message):
        return {'role': 'user', 'content': message}

    def assistant_message(self, message):
        return {'role': 'assistant', 'content': message}

    def submit_prompt(self, prompt):
        """Send the chat messages via Generation.call and return the reply text.

        :param prompt: list of {'role', 'content'} message dicts.
        :return: the assistant reply, or None when status_code is not 200.
        """
        # The last successful exchange is kept in a module-level global for debugging.
        global DEBUG_INFO
        resp = Generation.call(
            model=self.model_name,
            messages=prompt,
            seed=random.randint(1, 10000),
            result_format='message',
            api_key=self.api_key)
        if resp["status_code"] != 200:
            return None
        answer = resp.output.choices[0].message.content
        DEBUG_INFO = (prompt, answer)
        return answer

    def generate_sql(self, question, sql_result, ddl_result, doc_result):
        """Generate SQL for *question* from retrieved examples, DDLs and docs.

        :return: the extracted SQL text, or None if the LLM call failed.
        """
        prompt = self.get_sql_prompt(
            question=question,
            sql_result=sql_result,
            ddl_result=ddl_result,
            doc_result=doc_result)
        print("SQL Prompt:", prompt)
        llm_response = self.submit_prompt(prompt)
        return self.extrat_sql(llm_response)

    def extrat_sql(self, llm_response):
        """Pull the SQL statement out of an LLM reply.

        Patterns are tried in order: the last ``WITH ...;``, then the last
        ``SELECT ...;``, then the last ```sql fenced block, then the last
        generic fenced block. If none matches, the reply is returned
        unchanged. None (a failed LLM call) is passed through as None.
        """
        if llm_response is None:
            return None
        # Non-greedy fences so each fenced block is captured separately and
        # the last one wins, instead of one span from first to last fence.
        patterns = (
            r"WITH.*?;",
            r"SELECT.*?;",
            r"```sql\n(.*?)```",
            r"```(.*?)```",
        )
        for pattern in patterns:
            sqls = re.findall(pattern, llm_response, re.DOTALL)
            if sqls:
                return sqls[-1]
        return llm_response

    def get_sql_prompt(self, question, sql_result, ddl_result, doc_result):
        """Assemble the message list sent to the LLM.

        The system message holds the instructions, the DDLs, the documentation
        and the response guidelines. Each retrieved example is then added as a
        user/assistant turn, and the real question is added as the final user
        turn.
        """
        initial_prompt = "You are a SQL expert. " + \
            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_result)
        initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)
        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please explain why it can't be generated. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
        )
        message_log = [self.system_message(initial_prompt)]
        message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)
        return message_log

    def add_ddl_to_prompt(self, initial_prompt, ddl_result):
        """Append a ``===Tables`` section listing each retrieved DDL.

        :param initial_prompt: prompt text built so far.
        :param ddl_result: list of dicts with a 'ddl' key; None/empty adds nothing.
        :return: the extended prompt text.
        """
        ddl_list = [ddl_['ddl'] for ddl_ in ddl_result or []]
        if ddl_list:
            initial_prompt += "\n===Tables \n"
            for ddl in ddl_list:
                initial_prompt += f"{ddl}\n\n"
        return initial_prompt

    def add_sqlques_to_prompt(self, question, sql_result, message_log):
        """Append examples as user/assistant turns, then the actual question.

        :param question: the user's question (always appended last).
        :param sql_result: list of dicts with 'question' and 'sql' keys;
            malformed entries are skipped, None/empty adds no examples.
        :param message_log: message list to extend (modified in place).
        :return: the same message list.
        """
        for example in sql_result or []:
            if example is not None and "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))
        message_log.append(self.user_message(question))
        return message_log

    def add_documentation_to_prompt(self, initial_prompt, doc_result):
        """Append an ``===Additional Context`` section with each document."""
        if doc_result:
            initial_prompt += "\n===Additional Context \n\n"
            for doc in doc_result:
                initial_prompt += f"{doc}\n\n"
        return initial_prompt
if __name__ == '__main__':
    # Demo of the generation step: retrieve DDLs and example SQL for one
    # database, then ask the LLM to write SQL for the question.
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    question = "查询7月18日所有的儿童观察记录?"
    dataset_id = 111
    # Arguments: question, k (neighbours searched), cut-off, dataSetID, max results.
    ddl_result = DdlQuesSearch(model).find_vec_ddl(question, 2, 2000, dataset_id, 2)
    print(ddl_result)
    sql_result = SqlQuesSearch(model).find_vec_sqlque(question, 3, 2000, dataset_id, 3)
    # No document retriever is imported in this module, so no extra docs are passed.
    sql = Genarate().generate_sql(question, sql_result, ddl_result, [])
    print(sql)
6.执行结果显示
如图可以看到,正确生成了SQL并且可以正常执行;由于表是单独拉取过来的,表中没有数据,所以查询结果为空。
需要源码的同学,可以留言。