使用LlamaIndex实现文本到SQL的高级查询管道

在本文中,我们将向您展示如何使用LlamaIndex设置一个文本到SQL的查询管道。这将使您能够在数据上动态检索相关表,并在SQL提示中嵌入/索引每一行,从而在文本到SQL的提示中动态检索示例行。

加载和摄取数据

我们使用WikiTableQuestions数据集作为我们的测试数据集。以下是加载和摄取数据的步骤:

# 安装必要的包
%pip install llama-index-llms-openai

# 下载和解压数据集
!wget "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip" -O data.zip
!unzip data.zip

# 加载数据
import pandas as pd
from pathlib import Path

data_dir = Path("./WikiTableQuestions/csv/200-csv")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

提取每个表的表名和摘要

我们使用GPT-3.5提取每个表的表名和摘要。

# 创建目录存储表信息
tableinfo_dir = "WikiTableQuestions_TableInfo"
!mkdir {tableinfo_dir}

from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.llms.openai import OpenAI

class TableInfo(BaseModel):
    """表格信息"""
    table_name: str = Field(..., description="表名,必须使用下划线并且不能有空格")
    table_summary: str = Field(..., description="表的简短摘要/标题")

prompt_str = """\
给我一个表格摘要,格式如下:
- 表名必须唯一,并且简洁描述表格内容。
- 不要输出通用的表名(例如:table, my_table)。
排除以下表名:{exclude_table_name_list}
表格内容:
{table_str}
摘要:"""

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    llm=OpenAI(model="gpt-3.5-turbo"),
    prompt_template_str=prompt_str,
)

import json

def _get_tableinfo_with_index(idx: int) -> str:
    results_gen = Path(tableinfo_dir).glob(f"{idx}_*")
    results_list = list(results_gen)
    if len(results_list) == 0:
        return None
    elif len(results_list) == 1:
        path = results_list[0]
        return TableInfo.parse_file(path)
    else:
        raise ValueError(f"More than one file matching index: {list(results_gen)}")

table_names = set()
table_infos = []
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info:
        table_infos.append(table_info)
    else:
        while True:
            df_str = df.head(10).to_csv()
            table_info = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processed table: {table_name}")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            else:
                print(f"Table name {table_name} already exists, trying again.")
                pass

        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        json.dump(table_info.dict(), open(out_file, "w"))
    table_infos.append(table_info)

将数据放入SQL数据库

我们使用SQLAlchemy将所有表加载到一个SQLite数据库中。

from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer
import re

def sanitize_column_name(col_name):
    return re.sub(r"\W+", "_", col_name)

def create_table_from_dataframe(df: pd.DataFrame, table_name: str, engine, metadata_obj):
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]
    table = Table(table_name, metadata_obj, *columns)
    metadata_obj.create_all(engine)
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

高级功能1:查询时间表检索的文本到SQL

定义模块以设置端到端的文本到SQL与表检索。

from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema
from llama_index.core import SQLDatabase, VectorStoreIndex

sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]

obj_index = ObjectIndex.from_objects(table_schema_objs, table_node_mapping, VectorStoreIndex)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)

运行一些查询

现在我们准备好在整个管道上运行一些查询。

response = qp.run(query="What was the year that The Notorious B.I.G was signed to Bad Boy?")
print(str(response))

可能遇到的错误

  1. 数据解析错误:在加载CSV文件时,可能会遇到解析错误,确保数据文件格式正确。
  2. SQL查询错误:生成的SQL查询可能包含不存在的列名或表名,确保提示模板正确处理。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值