构建从头开始的响应合成: AI模型优化策略解析

在本教程中,我们将展示如何从头开始构建RAG(Retrieve-Augment-Generate)管道中的“LLM合成”组件。给定一组检索到的数据节点,我们将介绍如何合成响应,即使检索到的上下文信息超出了上下文窗口。

我们将讲解一些合成策略:

  1. 创建并精炼
  2. 树状摘要

最初我们使用OpenAI作为默认的LLM(大模型),但你可以根据需求插入任何LLM。

环境设置

我们将创建一个空的Pinecone索引,并定义必要的LlamaIndex封装/抽象,以便能够加载/索引数据并获取一个向量检索器。

# 安装必要的库
%pip install llama-index-readers-file pymupdf
%pip install llama-index-vector-stores-pinecone
%pip install llama-index-llms-openai

# 创建数据目录并下载示例数据
!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"

# 加载数据
from pathlib import Path
from llama_index.readers.file import PyMuPDFReader

loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")

构建Pinecone索引并获取检索器

我们使用高级LlamaIndex抽象来1)将数据摄取到Pinecone中,然后2)获取一个向量检索器。注意,我们将块大小设为1024。

import pinecone
import os

api_key = os.getenv("PINECONE_API_KEY")
pinecone.init(api_key=api_key, environment="us-west1-gcp")

# 创建Pinecone索引
pinecone.create_index("quickstart", dimension=1536, metric="euclidean", pod_type="p1")
pinecone_index = pinecone.Index("quickstart")

# 初始化向量存储
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import StorageContext

vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
splitter = SentenceSplitter(chunk_size=1024)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, transformations=[splitter], storage_context=storage_context)

retriever = index.as_retriever()

提取相关节点

我们使用检索器来获取给定用户查询的相关节点。这些节点将被传递给下面的响应合成模块。

query_str = "Can you tell me about results from RLHF using both model-based and human-based evaluation?"

retrieved_nodes = retriever.retrieve(query_str)

使用LLM进行响应合成

在此部分,我们将展示如何使用LLM和提示来构建响应合成模块。

1. 尝试一个简单的提示

我们首先尝试使用单个输入提示和LLM调用来合成响应。

from llama_index.llms.openai import OpenAI
from llama_index.core import PromptTemplate

llm = OpenAI(model="text-davinci-003")

qa_prompt = PromptTemplate(
    """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer: \
"""
)

def generate_response(retrieved_nodes, query_str, qa_prompt, llm):
    context_str = "\n\n".join([r.get_content() for r in retrieved_nodes])
    fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
    response = llm.complete(fmt_qa_prompt)
    return str(response), fmt_qa_prompt

response, fmt_qa_prompt = generate_response(retrieved_nodes, query_str, qa_prompt, llm)
print(f"*****Response******:\n{response}\n\n")

2. 创建并精炼策略

为了解决上下文溢出问题,我们可以尝试通过所有节点顺序合成响应的策略。首先使用第一个节点生成一个初始响应,然后对于后续节点,使用额外的上下文来细化答案。

refine_prompt = PromptTemplate(
    """\
The original query is as follows: {query_str}
We have provided an existing answer: {existing_answer}
We have the opportunity to refine the existing answer \
(only if needed) with some more context below.
------------
{context_str}
------------
Given the new context, refine the original answer to better answer the query. \
If the context isn't useful, return the original answer.
Refined Answer: \
"""
)

def generate_response_cr(retrieved_nodes, query_str, qa_prompt, refine_prompt, llm):
    cur_response = None
    fmt_prompts = []
    for idx, node in enumerate(retrieved_nodes):
        context_str = node.get_content()
        if idx == 0:
            fmt_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
        else:
            fmt_prompt = refine_prompt.format(context_str=context_str, query_str=query_str, existing_answer=str(cur_response))
        cur_response = llm.complete(fmt_prompt)
        fmt_prompts.append(fmt_prompt)
    return str(cur_response), fmt_prompts

response, fmt_prompts = generate_response_cr(retrieved_nodes, query_str, qa_prompt, refine_prompt, llm)
print(str(response))

3. 树状摘要策略

另一种方法是尝试树状摘要策略。我们为每个节点独立生成一个答案,然后层次化地组合这些答案。

def combine_results(texts, query_str, qa_prompt, llm, cur_prompt_list, num_children=10):
    new_texts = []
    for idx in range(0, len(texts), num_children):
        text_batch = texts[idx : idx + num_children]
        context_str = "\n\n".join([t for t in text_batch])
        fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
        combined_response = llm.complete(fmt_qa_prompt)
        new_texts.append(str(combined_response))
        cur_prompt_list.append(fmt_qa_prompt)
    if len(new_texts) == 1:
        return new_texts[0]
    else:
        return combine_results(new_texts, query_str, qa_prompt, llm, num_children=num_children)

def generate_response_hs(retrieved_nodes, query_str, qa_prompt, llm, num_children=10):
    fmt_prompts = []
    node_responses = []
    for node in retrieved_nodes:
        context_str = node.get_content()
        fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
        node_response = llm.complete(fmt_qa_prompt)
        node_responses.append(node_response)
        fmt_prompts.append(fmt_qa_prompt)
    response_txt = combine_results([str(r) for r in node_responses], query_str, qa_prompt, llm, fmt_prompts, num_children=num_children)
    return response_txt, fmt_prompts

response, fmt_prompts = generate_response_hs(retrieved_nodes, query_str, qa_prompt, llm)
print(str(response))

4. 树状摘要策略的异步版本

树状摘要策略的一个优点是LLM调用可以并行化,从而大大加快响应合成速度。

import nest_asyncio
import asyncio

nest_asyncio.apply()

async def acombine_results(texts, query_str, qa_prompt, llm, cur_prompt_list, num_children=10):
    fmt_prompts = []
    for idx in range(0, len(texts), num_children):
        text_batch = texts[idx : idx + num_children]
        context_str = "\n\n".join([t for t in text_batch])
        fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
        fmt_prompts.append(fmt_qa_prompt)
        cur_prompt_list.append(fmt_qa_prompt)
    tasks = [llm.acomplete(p) for p in fmt_prompts]
    combined_responses = await asyncio.gather(*tasks)
    new_texts = [str(r) for r in combined_responses]
    if len(new_texts) == 1:
        return new_texts[0]
    else:
        return await acombine_results(new_texts, query_str, qa_prompt, llm, num_children=num_children)

async def agenerate_response_hs(retrieved_nodes, query_str, qa_prompt, llm, num_children=10):
    fmt_prompts = []
    node_responses = []
    for node in retrieved_nodes:
        context_str = node.get_content()
        fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
        fmt_prompts.append(fmt_qa_prompt)
    tasks = [llm.acomplete(p) for p in fmt_prompts]
    node_responses = await asyncio.gather(*tasks)
    response_txt = combine_results([str(r) for r in node_responses], query_str, qa_prompt, llm, fmt_prompts, num_children=num_children)
    return response_txt, fmt_prompts

response, fmt_prompts = await agenerate_response_hs(retrieved_nodes, query_str, qa_prompt, llm)
print(str(response))

整合所有内容

我们定义一个简单的查询引擎,它可以用检索器、提示、LLM等进行初始化。并实现一个简单的查询功能。

from llama_index.core.retrievers import BaseRetriever
from llama_index.core.llms import LLM
from dataclasses import dataclass
from typing import Optional, List

@dataclass
class Response:
    response: str
    source_nodes: Optional[List] = None

    def __str__(self):
        return self.response

class MyQueryEngine:
    def __init__(self, retriever: BaseRetriever, qa_prompt: PromptTemplate, llm: LLM, num_children=10):
        self._retriever = retriever
        self._qa_prompt = qa_prompt
        self._llm = llm
        self._num_children = num_children

    def query(self, query_str: str):
        retrieved_nodes = self._retriever.retrieve(query_str)
        response_txt, _ = generate_response_hs(retrieved_nodes, query_str, self._qa_prompt, self._llm, num_children=self._num_children)
        return Response(response_txt, source_nodes=retrieved_nodes)

    async def aquery(self, query_str: str):
        retrieved_nodes = await self._retriever.aretrieve(query_str)
        response_txt, _ = await agenerate_response_hs(retrieved_nodes, query_str, self._qa_prompt, self._llm, num_children=self._num_children)
        return Response(response_txt, source_nodes=retrieved_nodes)

query_engine = MyQueryEngine(retriever, qa_prompt, llm, num_children=10)
response = query_engine.query(query_str)
print(str(response))

可能遇到的错误

  1. API密钥未设置: 请确保Pinecone和OpenAI的API密钥已经正确设置。
  2. 上下文窗口溢出: 在进行提示生成时,可能会遇到上下文窗口溢出的问题,需要优化提示模板以减少信息量。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值