在本教程中,我们将展示如何从头开始构建RAG(Retrieve-Augment-Generate)管道中的“LLM合成”组件。给定一组检索到的数据节点,我们将介绍如何合成响应,即使检索到的上下文信息超出了上下文窗口。
我们将讲解一些合成策略:
- 创建并精炼
- 树状摘要
最初我们使用OpenAI作为默认的LLM(大模型),但你可以根据需求插入任何LLM。
环境设置
我们将创建一个空的Pinecone索引,并定义必要的LlamaIndex封装/抽象,以便能够加载/索引数据并获取一个向量检索器。
# 安装必要的库
%pip install llama-index-readers-file pymupdf
%pip install llama-index-vector-stores-pinecone
%pip install llama-index-llms-openai
# 创建数据目录并下载示例数据
!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
# 加载数据
from pathlib import Path
from llama_index.readers.file import PyMuPDFReader
loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")
构建Pinecone索引并获取检索器
我们使用高级LlamaIndex抽象来1)将数据摄取到Pinecone中,然后2)获取一个向量检索器。注意,我们将块大小设为1024。
import pinecone
import os
api_key = os.getenv("PINECONE_API_KEY")
pinecone.init(api_key=api_key, environment="us-west1-gcp")
# 创建Pinecone索引
pinecone.create_index("quickstart", dimension=1536, metric="euclidean", pod_type="p1")
pinecone_index = pinecone.Index("quickstart")
# 初始化向量存储
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import StorageContext
vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
splitter = SentenceSplitter(chunk_size=1024)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, transformations=[splitter], storage_context=storage_context)
retriever = index.as_retriever()
提取相关节点
我们使用检索器来获取给定用户查询的相关节点。这些节点将被传递给下面的响应合成模块。
query_str = "Can you tell me about results from RLHF using both model-based and human-based evaluation?"
retrieved_nodes = retriever.retrieve(query_str)
使用LLM进行响应合成
在此部分,我们将展示如何使用LLM和提示来构建响应合成模块。
1. 尝试一个简单的提示
我们首先尝试使用单个输入提示和LLM调用来合成响应。
from llama_index.llms.openai import OpenAI
from llama_index.core import PromptTemplate
llm = OpenAI(model="text-davinci-003")
qa_prompt = PromptTemplate(
"""\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer: \
"""
)
def generate_response(retrieved_nodes, query_str, qa_prompt, llm):
context_str = "\n\n".join([r.get_content() for r in retrieved_nodes])
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
response = llm.complete(fmt_qa_prompt)
return str(response), fmt_qa_prompt
response, fmt_qa_prompt = generate_response(retrieved_nodes, query_str, qa_prompt, llm)
print(f"*****Response******:\n{response}\n\n")
2. 创建并精炼策略
为了解决上下文溢出问题,我们可以尝试通过所有节点顺序合成响应的策略。首先使用第一个节点生成一个初始响应,然后对于后续节点,使用额外的上下文来细化答案。
refine_prompt = PromptTemplate(
"""\
The original query is as follows: {query_str}
We have provided an existing answer: {existing_answer}
We have the opportunity to refine the existing answer \
(only if needed) with some more context below.
------------
{context_str}
------------
Given the new context, refine the original answer to better answer the query. \
If the context isn't useful, return the original answer.
Refined Answer: \
"""
)
def generate_response_cr(retrieved_nodes, query_str, qa_prompt, refine_prompt, llm):
cur_response = None
fmt_prompts = []
for idx, node in enumerate(retrieved_nodes):
context_str = node.get_content()
if idx == 0:
fmt_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
else:
fmt_prompt = refine_prompt.format(context_str=context_str, query_str=query_str, existing_answer=str(cur_response))
cur_response = llm.complete(fmt_prompt)
fmt_prompts.append(fmt_prompt)
return str(cur_response), fmt_prompts
response, fmt_prompts = generate_response_cr(retrieved_nodes, query_str, qa_prompt, refine_prompt, llm)
print(str(response))
3. 树状摘要策略
另一种方法是尝试树状摘要策略。我们为每个节点独立生成一个答案,然后层次化地组合这些答案。
def combine_results(texts, query_str, qa_prompt, llm, cur_prompt_list, num_children=10):
new_texts = []
for idx in range(0, len(texts), num_children):
text_batch = texts[idx : idx + num_children]
context_str = "\n\n".join([t for t in text_batch])
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
combined_response = llm.complete(fmt_qa_prompt)
new_texts.append(str(combined_response))
cur_prompt_list.append(fmt_qa_prompt)
if len(new_texts) == 1:
return new_texts[0]
else:
return combine_results(new_texts, query_str, qa_prompt, llm, num_children=num_children)
def generate_response_hs(retrieved_nodes, query_str, qa_prompt, llm, num_children=10):
fmt_prompts = []
node_responses = []
for node in retrieved_nodes:
context_str = node.get_content()
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
node_response = llm.complete(fmt_qa_prompt)
node_responses.append(node_response)
fmt_prompts.append(fmt_qa_prompt)
response_txt = combine_results([str(r) for r in node_responses], query_str, qa_prompt, llm, fmt_prompts, num_children=num_children)
return response_txt, fmt_prompts
response, fmt_prompts = generate_response_hs(retrieved_nodes, query_str, qa_prompt, llm)
print(str(response))
4. 树状摘要策略的异步版本
树状摘要策略的一个优点是LLM调用可以并行化,从而大大加快响应合成速度。
import nest_asyncio
import asyncio
nest_asyncio.apply()
async def acombine_results(texts, query_str, qa_prompt, llm, cur_prompt_list, num_children=10):
fmt_prompts = []
for idx in range(0, len(texts), num_children):
text_batch = texts[idx : idx + num_children]
context_str = "\n\n".join([t for t in text_batch])
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
fmt_prompts.append(fmt_qa_prompt)
cur_prompt_list.append(fmt_qa_prompt)
tasks = [llm.acomplete(p) for p in fmt_prompts]
combined_responses = await asyncio.gather(*tasks)
new_texts = [str(r) for r in combined_responses]
if len(new_texts) == 1:
return new_texts[0]
else:
return await acombine_results(new_texts, query_str, qa_prompt, llm, num_children=num_children)
async def agenerate_response_hs(retrieved_nodes, query_str, qa_prompt, llm, num_children=10):
fmt_prompts = []
node_responses = []
for node in retrieved_nodes:
context_str = node.get_content()
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
fmt_prompts.append(fmt_qa_prompt)
tasks = [llm.acomplete(p) for p in fmt_prompts]
node_responses = await asyncio.gather(*tasks)
response_txt = combine_results([str(r) for r in node_responses], query_str, qa_prompt, llm, fmt_prompts, num_children=num_children)
return response_txt, fmt_prompts
response, fmt_prompts = await agenerate_response_hs(retrieved_nodes, query_str, qa_prompt, llm)
print(str(response))
整合所有内容
我们定义一个简单的查询引擎,它可以用检索器、提示、LLM等进行初始化。并实现一个简单的查询功能。
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.llms import LLM
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class Response:
response: str
source_nodes: Optional[List] = None
def __str__(self):
return self.response
class MyQueryEngine:
def __init__(self, retriever: BaseRetriever, qa_prompt: PromptTemplate, llm: LLM, num_children=10):
self._retriever = retriever
self._qa_prompt = qa_prompt
self._llm = llm
self._num_children = num_children
def query(self, query_str: str):
retrieved_nodes = self._retriever.retrieve(query_str)
response_txt, _ = generate_response_hs(retrieved_nodes, query_str, self._qa_prompt, self._llm, num_children=self._num_children)
return Response(response_txt, source_nodes=retrieved_nodes)
async def aquery(self, query_str: str):
retrieved_nodes = await self._retriever.aretrieve(query_str)
response_txt, _ = await agenerate_response_hs(retrieved_nodes, query_str, self._qa_prompt, self._llm, num_children=self._num_children)
return Response(response_txt, source_nodes=retrieved_nodes)
query_engine = MyQueryEngine(retriever, qa_prompt, llm, num_children=10)
response = query_engine.query(query_str)
print(str(response))
可能遇到的错误
- API密钥未设置: 请确保Pinecone和OpenAI的API密钥已经正确设置。
- 上下文窗口溢出: 在进行提示生成时,可能会遇到上下文窗口溢出的问题,需要优化提示模板以减少信息量。
如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!
参考资料: