微软GraphRAG :TextUnit 切分的源码解读

GraphRAG 中 TextUnit 切分的源码解读

一、引言

在 GraphRAG 系统里,文本单元(TextUnit)的切分是一项关键操作,它有助于对大规模文本数据进行更细致的处理和分析。本文将深入剖析 GraphRAG 中与 TextUnit 切分相关的源码,详细介绍其核心逻辑与实现方式。

二、主要类和函数概述

2.1 TextSplitter 抽象基类

TextSplitter 作为抽象基类,为文本切分定义了基本接口,其主要属性和方法如下:

class TextSplitter(ABC):
    def __init__(
        self,
        chunk_size: int = 8191,
        chunk_overlap: int = 100,
        length_function: LengthFn = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ):
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str | list[str]) -> Iterable[str]:
        pass
  • __init__ 方法:用于初始化切分参数,像 chunk_size(每个文本块的最大大小)、chunk_overlap(文本块之间的重叠部分)等。
  • split_text 方法:这是一个抽象方法,具体的切分逻辑由子类实现。

2.2 TokenTextSplitter

TokenTextSplitter 继承自 TextSplitter,它基于令牌(token)对文本进行切分。

class TokenTextSplitter(TextSplitter):
    def __init__(
        self,
        encoding_name: str = defs.ENCODING_MODEL,
        model_name: str | None = None,
        allowed_special: Literal["all"] | set[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] = "all",
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        if model_name is not None:
            try:
                enc = tiktoken.encoding_for_model(model_name)
            except KeyError:
                log.exception("Model %s not found, using %s", model_name, encoding_name)
                enc = tiktoken.get_encoding(encoding_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special or set()
        self._disallowed_special = disallowed_special

    def encode(self, text: str) -> list[int]:
        return self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )

    def num_tokens(self, text: str) -> int:
        return len(self.encode(text))

    def split_text(self, text: str | list[str]) -> list[str]:
        if isinstance(text, list):
            text = " ".join(text)
        elif cast("bool", pd.isna(text)) or text == "":
            return []
        if not isinstance(text, str):
            msg = f"Attempting to split a non-string value, actual is {type(text)}"
            raise TypeError(msg)

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=lambda text: self.encode(text),
        )

        return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
  • __init__ 方法:初始化令牌编码器,若指定的模型名称不存在,会使用默认的编码名称。
  • encode 方法:将文本编码为整数向量。
  • num_tokens 方法:返回文本中的令牌数量。
  • split_text 方法:将文本按令牌进行切分,调用 split_single_text_on_tokens 函数完成具体的切分操作。

2.3 split_single_text_on_tokens 函数

def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
    result = []
    input_ids = tokenizer.encode(text)

    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]

    while start_idx < len(input_ids):
        chunk_text = tokenizer.decode(list(chunk_ids))
        result.append(chunk_text)
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]

    return result

此函数将单个文本按令牌切分为多个文本块。具体步骤如下:

  1. 对输入文本进行编码,得到整数向量 input_ids
  2. 从起始位置开始,按照 tokens_per_chunk 的大小提取令牌,形成一个文本块。
  3. 对提取的令牌进行解码,得到文本块的内容,并添加到结果列表中。
  4. 移动起始位置,考虑重叠部分,继续提取下一个文本块,直到处理完所有令牌。

2.4 split_multiple_texts_on_tokens 函数

def split_multiple_texts_on_tokens(
    texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
    result = []
    mapped_ids = []

    for source_doc_idx, text in enumerate(texts):
        encoded = tokenizer.encode(text)
        if tick:
            tick(1)
        mapped_ids.append((source_doc_idx, encoded))

    input_ids = [
        (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
    ]

    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]

    while start_idx < len(input_ids):
        chunk_text = tokenizer.decode([id for _, id in chunk_ids])
        doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
        result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]

    return result

该函数用于处理多个文本的切分,会返回带有元数据的文本块列表。具体步骤如下:

  1. 对每个文本进行编码,记录其来源文档的索引。
  2. 将所有编码后的令牌合并为一个列表 input_ids,并记录每个令牌的来源文档索引。
  3. 按照 tokens_per_chunk 的大小提取令牌,形成文本块。
  4. 对提取的令牌进行解码,得到文本块的内容,并记录该文本块涉及的文档索引。
  5. 创建 TextChunk 对象,包含文本块内容、涉及的文档索引和令牌数量,添加到结果列表中。
  6. 移动起始位置,考虑重叠部分,继续提取下一个文本块,直到处理完所有令牌。

三、文本切分的工作流程

3.1 配置切分参数

TextSplitter 及其子类的初始化过程中,会设置切分参数,如 chunk_sizechunk_overlap

3.2 选择切分策略

TokenTextSplitter 采用基于令牌的切分策略,使用 tiktoken 库进行编码和解码操作。

3.3 执行切分操作

根据输入的文本类型(单个文本或多个文本),调用相应的切分函数(split_single_text_on_tokenssplit_multiple_texts_on_tokens)进行切分。

3.4 处理切分结果

切分后的文本块会被存储在列表中,可用于后续的处理和分析。

四、总结

GraphRAG 中的 TextUnit 切分功能主要由 TextSplitter 抽象基类和 TokenTextSplitter 子类实现,通过 split_single_text_on_tokenssplit_multiple_texts_on_tokens 函数完成具体的切分操作。这种基于令牌的切分方式能够更精确地控制文本块的大小,适用于处理大规模的文本数据。同时,代码中还考虑了进度跟踪和错误处理,提高了系统的健壮性和可维护性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值