GraphRAG 中 TextUnit 切分的源码解读
一、引言
在 GraphRAG 系统里,文本单元(TextUnit)的切分是一项关键操作,它有助于对大规模文本数据进行更细致的处理和分析。本文将深入剖析 GraphRAG 中与 TextUnit 切分相关的源码,详细介绍其核心逻辑与实现方式。
二、主要类和函数概述
2.1 TextSplitter
抽象基类
TextSplitter
作为抽象基类,为文本切分定义了基本接口,其主要属性和方法如下:
class TextSplitter(ABC):
def __init__(
self,
chunk_size: int = 8191,
chunk_overlap: int = 100,
length_function: LengthFn = len,
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
):
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator
self._add_start_index = add_start_index
self._strip_whitespace = strip_whitespace
@abstractmethod
def split_text(self, text: str | list[str]) -> Iterable[str]:
pass
__init__
方法:用于初始化切分参数,像chunk_size
(每个文本块的最大大小)、chunk_overlap
(文本块之间的重叠部分)等。split_text
方法:这是一个抽象方法,具体的切分逻辑由子类实现。
2.2 TokenTextSplitter
类
TokenTextSplitter
继承自 TextSplitter
,它基于令牌(token)对文本进行切分。
class TokenTextSplitter(TextSplitter):
def __init__(
self,
encoding_name: str = defs.ENCODING_MODEL,
model_name: str | None = None,
allowed_special: Literal["all"] | set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
super().__init__(**kwargs)
if model_name is not None:
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
log.exception("Model %s not found, using %s", model_name, encoding_name)
enc = tiktoken.get_encoding(encoding_name)
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special or set()
self._disallowed_special = disallowed_special
def encode(self, text: str) -> list[int]:
return self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
def num_tokens(self, text: str) -> int:
return len(self.encode(text))
def split_text(self, text: str | list[str]) -> list[str]:
if isinstance(text, list):
text = " ".join(text)
elif cast("bool", pd.isna(text)) or text == "":
return []
if not isinstance(text, str):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
tokenizer = Tokenizer(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=lambda text: self.encode(text),
)
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
__init__
方法:初始化令牌编码器,若指定的模型名称不存在,会使用默认的编码名称。encode
方法:将文本编码为整数向量。num_tokens
方法:返回文本中的令牌数量。split_text
方法:将文本按令牌进行切分,调用split_single_text_on_tokens
函数完成具体的切分操作。
2.3 split_single_text_on_tokens
函数
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
result = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = tokenizer.decode(list(chunk_ids))
result.append(chunk_text)
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result
此函数将单个文本按令牌切分为多个文本块。具体步骤如下:
- 对输入文本进行编码,得到整数向量
input_ids
。 - 从起始位置开始,按照
tokens_per_chunk
的大小提取令牌,形成一个文本块。 - 对提取的令牌进行解码,得到文本块的内容,并添加到结果列表中。
- 移动起始位置,考虑重叠部分,继续提取下一个文本块,直到处理完所有令牌。
2.4 split_multiple_texts_on_tokens
函数
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
result = []
mapped_ids = []
for source_doc_idx, text in enumerate(texts):
encoded = tokenizer.encode(text)
if tick:
tick(1)
mapped_ids.append((source_doc_idx, encoded))
input_ids = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result
该函数用于处理多个文本的切分,会返回带有元数据的文本块列表。具体步骤如下:
- 对每个文本进行编码,记录其来源文档的索引。
- 将所有编码后的令牌合并为一个列表
input_ids
,并记录每个令牌的来源文档索引。 - 按照
tokens_per_chunk
的大小提取令牌,形成文本块。 - 对提取的令牌进行解码,得到文本块的内容,并记录该文本块涉及的文档索引。
- 创建
TextChunk
对象,包含文本块内容、涉及的文档索引和令牌数量,添加到结果列表中。 - 移动起始位置,考虑重叠部分,继续提取下一个文本块,直到处理完所有令牌。
三、文本切分的工作流程
3.1 配置切分参数
在 TextSplitter
及其子类的初始化过程中,会设置切分参数,如 chunk_size
和 chunk_overlap
。
3.2 选择切分策略
TokenTextSplitter
采用基于令牌的切分策略,使用 tiktoken
库进行编码和解码操作。
3.3 执行切分操作
根据输入的文本类型(单个文本或多个文本),调用相应的切分函数(split_single_text_on_tokens
或 split_multiple_texts_on_tokens
)进行切分。
3.4 处理切分结果
切分后的文本块会被存储在列表中,可用于后续的处理和分析。
四、总结
GraphRAG 中的 TextUnit 切分功能主要由 TextSplitter
抽象基类和 TokenTextSplitter
子类实现,通过 split_single_text_on_tokens
和 split_multiple_texts_on_tokens
函数完成具体的切分操作。这种基于令牌的切分方式能够更精确地控制文本块的大小,适用于处理大规模的文本数据。同时,代码中还考虑了进度跟踪和错误处理,提高了系统的健壮性和可维护性。