使用RankGPT进行AI文本检索重排序

在AI技术领域,文本检索是一个非常重要的应用场景。传统的文本检索方法通常依赖于关键词匹配和基于统计的模型。然而,随着大规模语言模型(如GPT-3.5和Mistral)的发展,我们可以通过引入这些模型来实现更智能、更准确的文本检索。本文将介绍如何使用RankGPT进行文本检索的重排序。

RankGPT简介

RankGPT是一种基于大规模语言模型的零样本列表式段落重排序方法。它通过生成排列和滑动窗口策略高效地重排序段落。在本文中,我们将展示如何使用OpenAI的GPT-3.5模型和Mistral模型来实现RankGPT重排序。

环境配置

在开始之前,我们需要安装一些必要的Python库:

%pip install llama-index-postprocessor-rankgpt-rerank
%pip install llama-index-llms-huggingface
%pip install llama-index-llms-openai
%pip install llama-index-llms-ollama

数据加载与索引构建

首先,我们从维基百科下载有关Van Gogh的文本,并构建向量存储索引。

import nest_asyncio
import logging
import sys
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.llms.openai import OpenAI
from pathlib import Path
import requests
import os

nest_asyncio.apply()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

OPENAI_API_TOKEN = "sk-"
os.environ["OPENAI_API_KEY"] = OPENAI_API_TOKEN

# 下载数据
wiki_titles = ["Vincent van Gogh"]
data_path = Path("data_wiki")
for title in wiki_titles:
    response = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={
            "action": "query",
            "format": "json",
            "titles": title,
            "prop": "extracts",
            "explaintext": True,
        },
    ).json()
    page = next(iter(response["query"]["pages"].values()))
    wiki_text = page["extract"]
    if not data_path.exists():
        Path.mkdir(data_path)
    with open(data_path / f"{title}.txt", "w") as fp:
        fp.write(wiki_text)

# 加载文档
documents = SimpleDirectoryReader("./data_wiki/").load_data()

# 构建索引
index = VectorStoreIndex.from_documents(documents)

检索与重排序

接下来,我们设置检索器和RankGPT重排序器,并比较有无重排序的检索结果。

from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import QueryBundle
from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank
import pandas as pd
from IPython.display import display, HTML

def get_retrieved_nodes(query_str, vector_top_k=10, reranker_top_n=3, with_reranker=False):
    query_bundle = QueryBundle(query_str)
    retriever = VectorIndexRetriever(index=index, similarity_top_k=vector_top_k)
    retrieved_nodes = retriever.retrieve(query_bundle)

    if with_reranker:
        reranker = RankGPTRerank(
            llm=OpenAI(model="gpt-3.5-turbo-16k", temperature=0.0, api_key=OPENAI_API_TOKEN),
            top_n=reranker_top_n,
            verbose=True,
        )
        retrieved_nodes = reranker.postprocess_nodes(retrieved_nodes, query_bundle)

    return retrieved_nodes

def pretty_print(df):
    return display(HTML(df.to_html().replace("\\n", "<br>")))

def visualize_retrieved_nodes(nodes):
    result_dicts = []
    for node in nodes:
        result_dict = {"Score": node.score, "Text": node.node.get_text()}
        result_dicts.append(result_dict)
    pretty_print(pd.DataFrame(result_dicts))

# 不使用重排序进行检索
new_nodes = get_retrieved_nodes("Which date did Paul Gauguin arrive in Arles?", vector_top_k=3, with_reranker=False)
visualize_retrieved_nodes(new_nodes)

# 使用RankGPT进行重排序
new_nodes = get_retrieved_nodes("Which date did Paul Gauguin arrive in Arles?", vector_top_k=10, reranker_top_n=3, with_reranker=True)
visualize_retrieved_nodes(new_nodes)

使用其他LLM进行RankGPT重排序

我们还可以使用Ollama提供的本地Mistral模型进行RankGPT重排序。

from llama_index.llms.ollama import Ollama

llm = Ollama(model="mistral", request_timeout=30.0)

def get_retrieved_nodes_with_ollama(query_str, vector_top_k=5, reranker_top_n=3, with_reranker=False):
    query_bundle = QueryBundle(query_str)
    retriever = VectorIndexRetriever(index=index, similarity_top_k=vector_top_k)
    retrieved_nodes = retriever.retrieve(query_bundle)

    if with_reranker:
        reranker = RankGPTRerank(llm=llm, top_n=reranker_top_n, verbose=True)
        retrieved_nodes = reranker.postprocess_nodes(retrieved_nodes, query_bundle)

    return retrieved_nodes

new_nodes = get_retrieved_nodes_with_ollama("Which date did Paul Gauguin arrive in Arles?", vector_top_k=10, reranker_top_n=3, with_reranker=True)
visualize_retrieved_nodes(new_nodes)

可能遇到的错误

  1. 网络错误:在请求外部API时,可能会遇到网络连接错误。这时可以尝试检查网络连接,或者稍后重试。
  2. API密钥错误:使用OpenAI或其他LLM提供的服务时,API密钥不正确或过期会导致请求失败。请确保API密钥有效并正确配置。
  3. 数据加载错误:在下载维基百科数据时,可能会遇到请求失败或数据格式错误。这时可以尝试重新下载数据或检查数据格式。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值