深入解析 LLMRerank:一种基于LLM的节点重排序器
在自然语言处理(NLP)领域,节点重排序是一个关键的步骤。它涉及对检索到的节点进行重新排序,以便更好地匹配查询需求。今天,我们将深入探讨一种名为 LLMRerank
的节点重排序器,它利用大型语言模型(LLM)来对节点进行重排序。这种重排序器在处理复杂查询时尤为有用,因为它可以帮助我们更好地理解节点的相关性。
前置知识
在深入了解 LLMRerank
之前,我们需要掌握以下几个概念:
- 节点(Node):在NLP中,节点是文档的基本单元。它可以是一个句子、一个段落或一个词语。
- 大型语言模型(LLM):一种基于深度学习的模型,能够理解和生成自然语言文本。
- 提示模板(Prompt Template):一种用于生成LLM输入的模板,通常包含占位符,用于插入动态内容。
- 查询包(Query Bundle):包含查询字符串和其他相关信息的包,用于指导节点的检索和重排序。
LLMRerank 的实现
LLMRerank
是一个基于 BaseNodePostprocessor
接口的类,它通过利用LLM对节点进行重排序。下面是其实现的详细解析:
导入必要的模块
首先,我们需要导入一些必要的模块和函数:
from typing import Callable, List, Optional
from llama_index.core.bridge.pydantic import Field, PrivateAttr, SerializeAsAny
from llama_index.core.indices.utils import (
default_format_node_batch_fn,
default_parse_choice_select_answer_fn,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.settings import Settings
定义 LLMRerank 类
现在,我们定义 LLMRerank
类,并为其添加必要的属性和方法:
class LLMRerank(BaseNodePostprocessor):
"""LLM-based reranker."""
top_n: int = Field(description="Top N nodes to return.")
choice_select_prompt: SerializeAsAny[BasePromptTemplate] = Field(
description="Choice select prompt."
)
choice_batch_size: int = Field(description="Batch size for choice select.")
llm: LLM = Field(description="The LLM to rerank with.")
_format_node_batch_fn: Callable = PrivateAttr()
_parse_choice_select_answer_fn: Callable = PrivateAttr()
def __init__(
self,
llm: Optional[LLM] = None,
choice_select_prompt: Optional[BasePromptTemplate] = None,
choice_batch_size: int = 10,
format_node_batch_fn: Optional[Callable] = None,
parse_choice_select_answer_fn: Optional[Callable] = None,
top_n: int = 10,
) -> None:
choice_select_prompt = choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
llm = llm or Settings.llm
super().__init__(
llm=llm,
choice_select_prompt=choice_select_prompt,
choice_batch_size=choice_batch_size,
top_n=top_n,
)
self._format_node_batch_fn = (
format_node_batch_fn or default_format_node_batch_fn
)
self._parse_choice_select_answer_fn = (
parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
)
def _get_prompts(self) -> PromptDictType:
"""Get prompts."""
return {"choice_select_prompt": self.choice_select_prompt}
def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
if "choice_select_prompt" in prompts:
self.choice_select_prompt = prompts["choice_select_prompt"]
@classmethod
def class_name(cls) -> str:
return "LLMRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if query_bundle is None:
raise ValueError("Query bundle must be provided.")
if len(nodes) == 0:
return []
initial_results: List[NodeWithScore] = []
for idx in range(0, len(nodes), self.choice_batch_size):
nodes_batch = [
node.node for node in nodes[idx : idx + self.choice_batch_size]
]
query_str = query_bundle.query_str
fmt_batch_str = self._format_node_batch_fn(nodes_batch)
# call each batch independently
raw_response = self.llm.predict(
self.choice_select_prompt,
context_str=fmt_batch_str,
query_str=query_str,
)
raw_choices, relevances = self._parse_choice_select_answer_fn(
raw_response, len(nodes_batch)
)
choice_idxs = [int(choice) - 1 for choice in raw_choices]
choice_nodes = [nodes_batch[idx] for idx in choice_idxs]
relevances = relevances or [1.0 for _ in choice_nodes]
initial_results.extend(
[
NodeWithScore(node=node, score=relevance)
for node, relevance in zip(choice_nodes, relevances)
]
)
return sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
: self.top_n
]
详细解析
属性解析
- top_n:返回的前N个节点。
- choice_select_prompt:用于选择节点的提示模板。
- choice_batch_size:选择节点的批处理大小。
- llm:用于重排序的LLM。
方法解析
- init:初始化方法,用于设置默认参数和函数。
- _get_prompts:获取提示模板。
- _update_prompts:更新提示模板。
- class_name:返回类的名称。
- _postprocess_nodes:对节点进行重排序的核心方法。
实际应用示例
为了更好地理解 LLMRerank
的工作原理,我们来看一个实际的应用示例:
# 示例查询包
query_bundle = QueryBundle(query_str="什么是自然语言处理?")
# 示例节点列表
nodes = [
NodeWithScore(node=Node(text="自然语言处理是计算机科学的一个分支。"), score=0.8),
NodeWithScore(node=Node(text="它涉及计算机与人类语言之间的交互。"), score=0.7),
NodeWithScore(node=Node(text="自然语言处理使用机器学习算法。"), score=0.9),
]
# 创建 LLMRerank 实例
llm_rerank = LLMRerank(top_n=2)
# 对节点进行重排序
reranked_nodes = llm_rerank._postprocess_nodes(nodes, query_bundle)
# 输出重排序后的节点
for node in reranked_nodes:
print(f"Node: {node.node.text}, Score: {node.score}")
在这个示例中,我们首先定义了一个查询包和节点列表,然后使用 LLMRerank
对节点进行重排序,并输出重排序后的节点及其分数。
输出结果
假设重排序后的节点如下:
- “自然语言处理使用机器学习算法。”(分数:0.9)
- “自然语言处理是计算机科学的一个分支。”(分数:0.8)
那么,输出的节点信息可能如下:
Node: 自然语言处理使用机器学习算法。, Score: 0.9
Node: 自然语言处理是计算机科学的一个分支。, Score: 0.8
总结
LLMRerank
是一个高效的节点重排序器,它利用大型语言模型(LLM)对节点进行重排序,以便更好地匹配查询需求。通过这种方式,我们可以更好地理解节点的相关性,从而提高NLP任务的准确性。希望这篇博客能够帮助你全面理解 LLMRerank
的工作原理及实际应用。