A Deep Dive into BM25Retriever's Persistence and Retrieval Methods: Efficient Data Storage and Querying
In the previous two articles, we took a close look at the BM25Retriever class's initializer and its from_defaults class method. This article continues with the class's persistence and retrieval methods: get_persist_args, persist, from_persist_dir, and _retrieve. Together, they let you store a retriever efficiently and load it back later, improving both performance and maintainability.
Prerequisites
Before continuing, make sure you are familiar with the following concepts:
- Persistence: saving data to durable storage (such as disk) so it can be restored after the program restarts.
- JSON: a lightweight data-interchange format that is easy for humans to read and write, and easy for machines to parse and generate.
- QueryBundle: a wrapper class representing a query, holding the query string and related information.
- NodeWithScore: a class representing a node together with its score, used to hold retrieval results.
Method Walkthrough
The get_persist_args method
def get_persist_args(self) -> Dict[str, Any]:
    """Get Persist Args Dict to Save."""
    return {
        DEFAULT_PERSIST_ARGS[key]: getattr(self, key)
        for key in DEFAULT_PERSIST_ARGS
        if hasattr(self, key)
    }
Code walkthrough
- Purpose: collect the parameters that need to be persisted into a dictionary.
- Implementation: a dict comprehension iterates over the keys of DEFAULT_PERSIST_ARGS and, for each attribute the current object actually has, stores its value in the returned dictionary under the mapped output name.
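The renaming in this comprehension matters because the saved keys are later splatted back into the constructor, whose parameter is verbose, not _verbose. A minimal stdlib-only sketch of the pattern (TinyRetriever and its attribute values are illustrative, not part of the library):

```python
# Maps internal attribute names to the names used in the saved JSON;
# mirrors the shape of DEFAULT_PERSIST_ARGS in the source.
PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}


class TinyRetriever:
    def __init__(self) -> None:
        self.similarity_top_k = 5
        self._verbose = True

    def get_persist_args(self) -> dict:
        # Keep only attributes the instance actually has, stored under
        # their public (renamed) keys.
        return {
            PERSIST_ARGS[key]: getattr(self, key)
            for key in PERSIST_ARGS
            if hasattr(self, key)
        }


print(TinyRetriever().get_persist_args())  # {'similarity_top_k': 5, 'verbose': True}
```

Because of the `hasattr` guard, the method degrades gracefully if an attribute is ever removed from the class.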
The persist method
def persist(self, path: str, **kwargs: Any) -> None:
    """Persist the retriever to a directory."""
    self.bm25.save(path, corpus=self.corpus, **kwargs)
    with open(os.path.join(path, DEFAULT_PERSIST_FILENAME), "w") as f:
        json.dump(self.get_persist_args(), f, indent=2)
Code walkthrough
- Purpose: persist the retriever to the given directory.
- Implementation: call the bm25 object's save method to write the BM25 index and the corpus to the target path, then open DEFAULT_PERSIST_FILENAME in that directory and write the persist-args dictionary into it with json.dump.
The from_persist_dir class method
@classmethod
def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
    """Load the retriever from a directory."""
    bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
    with open(os.path.join(path, DEFAULT_PERSIST_FILENAME)) as f:
        retriever_data = json.load(f)
    return cls(existing_bm25=bm25, **retriever_data)
Code walkthrough
- Purpose: load a retriever from the given directory.
- Implementation: call bm25s.BM25.load to restore the BM25 index and corpus from the path, read the persist-args dictionary back with json.load, and pass the loaded BM25 object together with those parameters to the class constructor to create a BM25Retriever instance.
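The persist/load pair forms a round trip: the JSON written by persist is read back and splatted into the constructor. A stdlib-only sketch under the assumption of a toy class (TinyRetriever and its fields are illustrative; only the retriever.json file name mirrors DEFAULT_PERSIST_FILENAME):

```python
import json
import os
import tempfile

PERSIST_FILENAME = "retriever.json"


class TinyRetriever:
    def __init__(self, similarity_top_k: int = 2, verbose: bool = False) -> None:
        self.similarity_top_k = similarity_top_k
        self.verbose = verbose

    def persist(self, path: str) -> None:
        # Write scalar settings next to where the index files would live.
        with open(os.path.join(path, PERSIST_FILENAME), "w") as f:
            json.dump({"similarity_top_k": self.similarity_top_k,
                       "verbose": self.verbose}, f, indent=2)

    @classmethod
    def from_persist_dir(cls, path: str) -> "TinyRetriever":
        with open(os.path.join(path, PERSIST_FILENAME)) as f:
            data = json.load(f)
        # The JSON keys must match the constructor's parameter names --
        # this is exactly why get_persist_args renames _verbose to verbose.
        return cls(**data)


with tempfile.TemporaryDirectory() as path:
    TinyRetriever(similarity_top_k=7, verbose=True).persist(path)
    loaded = TinyRetriever.from_persist_dir(path)

print(loaded.similarity_top_k, loaded.verbose)  # 7 True
```

The real class follows the same shape, with bm25.save / bm25s.BM25.load handling the index files alongside the JSON.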
The _retrieve method
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    query = query_bundle.query_str

    tokenized_query = bm25s.tokenize(
        query, stemmer=self.stemmer, show_progress=self._verbose
    )
    indexes, scores = self.bm25.retrieve(
        tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
    )

    # batched, but only one query
    indexes = indexes[0]
    scores = scores[0]

    nodes: List[NodeWithScore] = []
    for idx, score in zip(indexes, scores):
        # idx can be an int or a dict of the node
        if isinstance(idx, dict):
            node = metadata_dict_to_node(idx)
        else:
            node_dict = self.corpus[int(idx)]
            node = metadata_dict_to_node(node_dict)
        nodes.append(NodeWithScore(node=node, score=float(score)))

    return nodes
Code walkthrough
- Purpose: retrieve the nodes most relevant to a query string.
- Implementation: extract the query string from query_bundle; tokenize it with bm25s.tokenize; call the bm25 object's retrieve method to get the top-k result indexes and scores; since the API is batched but we issued a single query, take the first row of each; convert every index (either an integer position into the corpus or an already-materialized node dict) and its score into a NodeWithScore; return the resulting list.
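The unbatching and conversion steps can be sketched without bm25s installed; fake_retrieve below merely imitates the (n_queries, k) return shape, and all names here are illustrative:

```python
corpus = [{"text": "doc a"}, {"text": "doc b"}, {"text": "doc c"}]


def fake_retrieve(k: int):
    # Imitates the batched return shape: one row per query, k columns.
    indexes = [[2, 0]]
    scores = [[1.7, 0.9]]
    return indexes, scores


indexes, scores = fake_retrieve(k=2)
# batched, but only one query: keep the first (and only) row
indexes, scores = indexes[0], scores[0]

results = []
for idx, score in zip(indexes, scores):
    # idx may be an int position into the corpus, or already a node dict
    node = idx if isinstance(idx, dict) else corpus[int(idx)]
    results.append((node, float(score)))

print(results)  # [({'text': 'doc c'}, 1.7), ({'text': 'doc a'}, 0.9)]
```

In the real method the recovered dict is additionally passed through metadata_dict_to_node to rebuild a node object before wrapping it in NodeWithScore.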
Example
Suppose we already have a persisted retriever and want to load it from its persist directory and run a query:
from some_module import BM25Retriever, QueryBundle

# Load the retriever from the persist directory
retriever = BM25Retriever.from_persist_dir(path="path_to_persist_dir")

# Run a query
query_bundle = QueryBundle(query_str="sample document")
results = retriever._retrieve(query_bundle)

# Print the results
for result in results:
    print(result)
Code walkthrough
- Loading the retriever: the from_persist_dir class method restores the persisted retriever from the given path.
- Running the query: create a QueryBundle holding the query string "sample document" and call _retrieve to get the results. (Application code would typically go through the public retrieve method rather than the underscore-prefixed _retrieve; it is called directly here only for illustration.)
- Printing the results: iterate over the results and print each one.
Summary
Through this walkthrough we have seen how BM25Retriever's persistence and retrieval methods fit together: they provide efficient storage and querying, keeping a system stable and fast even on large corpora. The examples and explanations above should let you pick the technique up quickly and apply it in your own projects. For reference, the full source of the class follows.
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, cast

from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
from llama_index.core.storage.docstore.types import BaseDocumentStore
from llama_index.core.vector_stores.utils import (
    node_to_metadata_dict,
    metadata_dict_to_node,
)

import bm25s
import Stemmer

logger = logging.getLogger(__name__)

DEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}

DEFAULT_PERSIST_FILENAME = "retriever.json"


class BM25Retriever(BaseRetriever):
    """A BM25 retriever that uses the BM25 algorithm to retrieve nodes.

    Args:
        nodes (List[BaseNode], optional):
            The nodes to index. If not provided, an existing BM25 object must be passed.
        stemmer (Stemmer.Stemmer, optional):
            The stemmer to use. Defaults to an english stemmer.
        language (str, optional):
            The language to use for stopword removal. Defaults to "en".
        existing_bm25 (bm25s.BM25, optional):
            An existing BM25 object to use. If not provided, nodes must be passed.
        similarity_top_k (int, optional):
            The number of results to return. Defaults to DEFAULT_SIMILARITY_TOP_K.
        callback_manager (CallbackManager, optional):
            The callback manager to use. Defaults to None.
        objects (List[IndexNode], optional):
            The objects to retrieve. Defaults to None.
        object_map (dict, optional):
            A map of object IDs to nodes. Defaults to None.
        verbose (bool, optional):
            Whether to show progress. Defaults to False.
    """

    def __init__(
        self,
        nodes: Optional[List[BaseNode]] = None,
        stemmer: Optional[Stemmer.Stemmer] = None,
        language: str = "en",
        existing_bm25: Optional[bm25s.BM25] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.stemmer = stemmer or Stemmer.Stemmer("english")
        self.similarity_top_k = similarity_top_k

        if existing_bm25 is not None:
            self.bm25 = existing_bm25
            self.corpus = existing_bm25.corpus
        else:
            if nodes is None:
                raise ValueError("Please pass nodes or an existing BM25 object.")

            self.corpus = [node_to_metadata_dict(node) for node in nodes]
            corpus_tokens = bm25s.tokenize(
                [node.get_content() for node in nodes],
                stopwords=language,
                stemmer=self.stemmer,
                show_progress=verbose,
            )
            self.bm25 = bm25s.BM25()
            self.bm25.index(corpus_tokens, show_progress=verbose)

        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        stemmer: Optional[Stemmer.Stemmer] = None,
        language: str = "en",
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
        # deprecated
        tokenizer: Optional[Callable[[str], List[str]]] = None,
    ) -> "BM25Retriever":
        if tokenizer is not None:
            logger.warning(
                "The tokenizer parameter is deprecated and will be removed in a future release. "
                "Use a stemmer from PyStemmer instead."
            )

        # ensure only one of index, nodes, or docstore is passed
        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        if index is not None:
            docstore = index.docstore

        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        assert (
            nodes is not None
        ), "Please pass exactly one of index, nodes, or docstore."

        return cls(
            nodes=nodes,
            stemmer=stemmer,
            language=language,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )

    def get_persist_args(self) -> Dict[str, Any]:
        """Get Persist Args Dict to Save."""
        return {
            DEFAULT_PERSIST_ARGS[key]: getattr(self, key)
            for key in DEFAULT_PERSIST_ARGS
            if hasattr(self, key)
        }

    def persist(self, path: str, **kwargs: Any) -> None:
        """Persist the retriever to a directory."""
        self.bm25.save(path, corpus=self.corpus, **kwargs)
        with open(os.path.join(path, DEFAULT_PERSIST_FILENAME), "w") as f:
            json.dump(self.get_persist_args(), f, indent=2)

    @classmethod
    def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
        """Load the retriever from a directory."""
        bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
        with open(os.path.join(path, DEFAULT_PERSIST_FILENAME)) as f:
            retriever_data = json.load(f)
        return cls(existing_bm25=bm25, **retriever_data)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        query = query_bundle.query_str

        tokenized_query = bm25s.tokenize(
            query, stemmer=self.stemmer, show_progress=self._verbose
        )
        indexes, scores = self.bm25.retrieve(
            tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
        )

        # batched, but only one query
        indexes = indexes[0]
        scores = scores[0]

        nodes: List[NodeWithScore] = []
        for idx, score in zip(indexes, scores):
            # idx can be an int or a dict of the node
            if isinstance(idx, dict):
                node = metadata_dict_to_node(idx)
            else:
                node_dict = self.corpus[int(idx)]
                node = metadata_dict_to_node(node_dict)
            nodes.append(NodeWithScore(node=node, score=float(score)))

        return nodes