In-Depth Analysis of BM25Retriever's Persistence and Retrieval Methods: Efficient Data Storage and Querying

In the previous two articles, we walked through the BM25Retriever class's initializer and its from_defaults class method. This article continues the deep dive with the class's persistence and retrieval methods: get_persist_args, persist, from_persist_dir, and _retrieve. With these methods, programmers can store and retrieve data efficiently, improving both the performance and the maintainability of a system.

Prerequisites

Before continuing, make sure you are familiar with the following concepts:

  1. Persistence: the process of writing data to durable storage (such as a hard disk) so that it can be restored after the program restarts.
  2. JSON: a lightweight data-interchange format that is easy for humans to read and write and easy for machines to parse and generate.
  3. QueryBundle: a wrapper class representing a query, carrying the query string and related information.
  4. NodeWithScore: a class pairing a node with a relevance score, used to hold retrieval results (see the short sketch after this list).
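
To make the last two concrete, here is a minimal sketch using classes from llama_index.core.schema, as imported in the source listing at the end of this article. TextNode is an assumption on my part, used only to build a demo node; it is not part of BM25Retriever itself:

from llama_index.core.schema import TextNode, QueryBundle, NodeWithScore

# QueryBundle wraps the raw query string that the retriever will tokenize.
query_bundle = QueryBundle(query_str="what is bm25?")

# NodeWithScore pairs a retrieved node with its relevance score.
node = TextNode(text="BM25 is a ranking function used in information retrieval.")
scored = NodeWithScore(node=node, score=1.23)
print(scored.score, scored.node.get_content())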

Method Walkthrough

The get_persist_args Method

def get_persist_args(self) -> Dict[str, Any]:
    """Get Persist Args Dict to Save."""
    return {
        DEFAULT_PERSIST_ARGS[key]: getattr(self, key)
        for key in DEFAULT_PERSIST_ARGS
        if hasattr(self, key)
    }
Code analysis
  1. Purpose

    • Builds the dictionary of parameters that should be persisted.
  2. Implementation

    • Uses a dict comprehension to iterate over the keys of DEFAULT_PERSIST_ARGS.
    • For each key, checks whether the current object has that attribute; if it does, reads the attribute's value and stores it in the returned dictionary under the mapped name (illustrated below).
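
As a standalone illustration of the same comprehension (with a hypothetical _Demo object standing in for the retriever), note how the internal attribute name _verbose is remapped to the constructor argument name verbose, so the resulting dict can later be passed straight back to __init__:

class _Demo:
    def __init__(self) -> None:
        self.similarity_top_k = 2
        self._verbose = False

DEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}

demo = _Demo()
persist_args = {
    DEFAULT_PERSIST_ARGS[key]: getattr(demo, key)
    for key in DEFAULT_PERSIST_ARGS
    if hasattr(demo, key)
}
print(persist_args)  # {'similarity_top_k': 2, 'verbose': False}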

The persist Method

def persist(self, path: str, **kwargs: Any) -> None:
    """Persist the retriever to a directory."""
    self.bm25.save(path, corpus=self.corpus, **kwargs)
    with open(os.path.join(path, DEFAULT_PERSIST_FILENAME), "w") as f:
        json.dump(self.get_persist_args(), f, indent=2)
Code analysis
  1. Purpose

    • Persists the retriever to the given directory.
  2. Implementation

    • Calls the bm25 object's save method to write the BM25 index and the corpus to the given path.
    • Opens the persist file in that directory and writes the persist-args dictionary to it with json.dump (a usage sketch follows below).
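
A minimal end-to-end persist sketch. The import path assumes the llama-index-retrievers-bm25 integration package is installed; the sample texts and directory name are placeholders for the demo:

import os

from llama_index.core.schema import TextNode
from llama_index.retrievers.bm25 import BM25Retriever  # assumed package path

nodes = [
    TextNode(text="BM25 ranks documents by term frequency and inverse document frequency."),
    TextNode(text="Persisting the retriever writes the index and corpus to disk for reuse."),
]
retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)

persist_dir = "./bm25_persist"
os.makedirs(persist_dir, exist_ok=True)
retriever.persist(persist_dir)
print(os.listdir(persist_dir))  # includes retriever.json alongside the bm25s index files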

The from_persist_dir Class Method

@classmethod
def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
    """Load the retriever from a directory."""
    bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
    with open(os.path.join(path, DEFAULT_PERSIST_FILENAME)) as f:
        retriever_data = json.load(f)
    return cls(existing_bm25=bm25, **retriever_data)
Code analysis
  1. Purpose

    • Loads the retriever from the given directory.
  2. Implementation

    • Calls bm25s.BM25.load to load the BM25 index and its corpus from the given path.
    • Opens the persist file in that directory and reads the persist-args dictionary with json.load.
    • Passes the loaded BM25 object and the parameter dictionary to the class initializer to construct a BM25Retriever instance (the snippet below shows what that JSON file holds).
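
To make the second step concrete, here is what from_persist_dir reads back from retriever.json, assuming the directory was written by the persist sketch above. The default persist arguments only cover similarity_top_k and verbose; the heavy data (index matrices and corpus) lives in the bm25s files:

import json
import os

persist_dir = "./bm25_persist"  # directory written by persist() above
with open(os.path.join(persist_dir, "retriever.json")) as f:
    retriever_data = json.load(f)
print(retriever_data)  # e.g. {'similarity_top_k': 2, 'verbose': False}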

The _retrieve Method

def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
    query = query_bundle.query_str
    tokenized_query = bm25s.tokenize(
        query, stemmer=self.stemmer, show_progress=self._verbose
    )
    indexes, scores = self.bm25.retrieve(
        tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
    )

    # batched, but only one query
    indexes = indexes[0]
    scores = scores[0]

    nodes: List[NodeWithScore] = []
    for idx, score in zip(indexes, scores):
        # idx can be an int or a dict of the node
        if isinstance(idx, dict):
            node = metadata_dict_to_node(idx)
        else:
            node_dict = self.corpus[int(idx)]
            node = metadata_dict_to_node(node_dict)
        nodes.append(NodeWithScore(node=node, score=float(score)))

    return nodes
Code analysis
  1. Purpose

    • Retrieves the nodes relevant to a query string.
  2. Implementation

    • Extracts the query string from query_bundle.
    • Tokenizes the query string with bm25s.tokenize.
    • Calls the bm25 object's retrieve method to obtain the indices and scores of the results.
    • Unpacks the batched results (there is only one query here) and converts each index/score pair into a NodeWithScore object.
    • Returns the list of results (the standalone bm25s sketch below shows the batched shapes).
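
The batched shape is easier to see with bm25s on its own. This standalone sketch (the toy corpus and query are assumptions) mirrors what _retrieve does before wrapping results into NodeWithScore objects:

import bm25s
import Stemmer

stemmer = Stemmer.Stemmer("english")
corpus = ["a cat sat on the mat", "dogs chase cats", "bm25 ranks documents by relevance"]

# Index the corpus the same way BM25Retriever.__init__ does.
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
bm25 = bm25s.BM25()
bm25.index(corpus_tokens)

# Both return values are batched with shape (num_queries, k).
query_tokens = bm25s.tokenize("cat on a mat", stemmer=stemmer)
indexes, scores = bm25.retrieve(query_tokens, k=2)
print(indexes[0], scores[0])  # top-k document indices and their BM25 scores for the single query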

Example Code

Suppose we already have a persisted retriever and want to load it from its persist directory and run a query (the import paths below assume the llama-index-retrievers-bm25 integration package is installed):

from llama_index.core.schema import QueryBundle
from llama_index.retrievers.bm25 import BM25Retriever

# Load the retriever from the persist directory
retriever = BM25Retriever.from_persist_dir(path="path_to_persist_dir")

# Run a query
query_bundle = QueryBundle(query_str="sample document")
results = retriever._retrieve(query_bundle)

# Print the results
for result in results:
    print(result)

Code Explanation

  1. Load the retriever

    • Use the from_persist_dir class method to load the persisted retriever from the given path.
  2. Run a query

    • Create a QueryBundle containing the query string "sample document".
    • Call _retrieve to run the query and obtain the results (a note on the public entry point follows below).
  3. Print the results

    • Iterate over the results and print each one.
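
In everyday use you would normally go through the public retrieve method inherited from BaseRetriever, which wraps _retrieve with callback handling and also accepts a plain string. A short sketch, assuming the same loaded retriever as above:

# Equivalent query through the public API inherited from BaseRetriever.
results = retriever.retrieve("sample document")
for result in results:
    print(result.score, result.node.get_content())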

Summary

In this walkthrough we took a close look at the persistence and retrieval methods of the BM25Retriever class. These methods provide efficient data storage and querying, making a system more stable and performant when handling large amounts of data. With the code examples and explanations above, programmers can quickly grasp and apply this retrieval technique; I hope this article serves as a useful reference for your own projects. For reference, the complete source of the BM25Retriever class analyzed in this series is reproduced below.

import json
import logging
import os

from typing import Any, Callable, Dict, List, Optional, cast

from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
from llama_index.core.storage.docstore.types import BaseDocumentStore
from llama_index.core.vector_stores.utils import (
    node_to_metadata_dict,
    metadata_dict_to_node,
)

import bm25s
import Stemmer


logger = logging.getLogger(__name__)

DEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}

DEFAULT_PERSIST_FILENAME = "retriever.json"


class BM25Retriever(BaseRetriever):
    """A BM25 retriever that uses the BM25 algorithm to retrieve nodes.

    Args:
        nodes (List[BaseNode], optional):
            The nodes to index. If not provided, an existing BM25 object must be passed.
        stemmer (Stemmer.Stemmer, optional):
            The stemmer to use. Defaults to an english stemmer.
        language (str, optional):
            The language to use for stopword removal. Defaults to "en".
        existing_bm25 (bm25s.BM25, optional):
            An existing BM25 object to use. If not provided, nodes must be passed.
        similarity_top_k (int, optional):
            The number of results to return. Defaults to DEFAULT_SIMILARITY_TOP_K.
        callback_manager (CallbackManager, optional):
            The callback manager to use. Defaults to None.
        objects (List[IndexNode], optional):
            The objects to retrieve. Defaults to None.
        object_map (dict, optional):
            A map of object IDs to nodes. Defaults to None.
        verbose (bool, optional):
            Whether to show progress. Defaults to False.
    """

    def __init__(
        self,
        nodes: Optional[List[BaseNode]] = None,
        stemmer: Optional[Stemmer.Stemmer] = None,
        language: str = "en",
        existing_bm25: Optional[bm25s.BM25] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.stemmer = stemmer or Stemmer.Stemmer("english")
        self.similarity_top_k = similarity_top_k

        if existing_bm25 is not None:
            self.bm25 = existing_bm25
            self.corpus = existing_bm25.corpus
        else:
            if nodes is None:
                raise ValueError("Please pass nodes or an existing BM25 object.")

            self.corpus = [node_to_metadata_dict(node) for node in nodes]

            corpus_tokens = bm25s.tokenize(
                [node.get_content() for node in nodes],
                stopwords=language,
                stemmer=self.stemmer,
                show_progress=verbose,
            )
            self.bm25 = bm25s.BM25()
            self.bm25.index(corpus_tokens, show_progress=verbose)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        stemmer: Optional[Stemmer.Stemmer] = None,
        language: str = "en",
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
        # deprecated
        tokenizer: Optional[Callable[[str], List[str]]] = None,
    ) -> "BM25Retriever":
        if tokenizer is not None:
            logger.warning(
                "The tokenizer parameter is deprecated and will be removed in a future release. "
                "Use a stemmer from PyStemmer instead."
            )

        # ensure only one of index, nodes, or docstore is passed
        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        if index is not None:
            docstore = index.docstore

        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        assert (
            nodes is not None
        ), "Please pass exactly one of index, nodes, or docstore."

        return cls(
            nodes=nodes,
            stemmer=stemmer,
            language=language,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )

    def get_persist_args(self) -> Dict[str, Any]:
        """Get Persist Args Dict to Save."""
        return {
            DEFAULT_PERSIST_ARGS[key]: getattr(self, key)
            for key in DEFAULT_PERSIST_ARGS
            if hasattr(self, key)
        }

    def persist(self, path: str, **kwargs: Any) -> None:
        """Persist the retriever to a directory."""
        self.bm25.save(path, corpus=self.corpus, **kwargs)
        with open(os.path.join(path, DEFAULT_PERSIST_FILENAME), "w") as f:
            json.dump(self.get_persist_args(), f, indent=2)

    @classmethod
    def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
        """Load the retriever from a directory."""
        bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
        with open(os.path.join(path, DEFAULT_PERSIST_FILENAME)) as f:
            retriever_data = json.load(f)
        return cls(existing_bm25=bm25, **retriever_data)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        query = query_bundle.query_str
        tokenized_query = bm25s.tokenize(
            query, stemmer=self.stemmer, show_progress=self._verbose
        )
        indexes, scores = self.bm25.retrieve(
            tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
        )

        # batched, but only one query
        indexes = indexes[0]
        scores = scores[0]

        nodes: List[NodeWithScore] = []
        for idx, score in zip(indexes, scores):
            # idx can be an int or a dict of the node
            if isinstance(idx, dict):
                node = metadata_dict_to_node(idx)
            else:
                node_dict = self.corpus[int(idx)]
                node = metadata_dict_to_node(node_dict)
            nodes.append(NodeWithScore(node=node, score=float(score)))

        return nodes
