A Deep Dive into SemanticSplitterNodeParser: The Art of Semantic Document Splitting
In natural language processing (NLP), splitting a document into meaningful segments is a key problem. SemanticSplitterNodeParser is a powerful tool that splits a document into multiple nodes based on semantic similarity. This article takes a close look at how the class is implemented, to help you understand how it works and how to apply it in practice.
1. Prerequisites
Before diving into the code, we need a few basic concepts:
- Node: in NLP, a node usually represents a segment of a document, such as a sentence or a paragraph.
- Embedding model: an embedding model converts text into vector representations that capture its semantic information.
- Semantic similarity: semantic similarity measures how close two pieces of text are in meaning (a small numeric sketch follows this list).
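To make semantic similarity concrete, here is a minimal sketch in plain numpy. The two vectors are made up purely for illustration; in practice they would come from an embedding model.

```python
import numpy as np

# Two made-up "embeddings" with only 4 dimensions; real embeddings have hundreds or thousands.
vec_a = np.array([0.1, 0.9, 0.3, 0.0])
vec_b = np.array([0.2, 0.8, 0.4, 0.1])

# Cosine similarity: close to 1.0 means very similar meaning, close to 0.0 means unrelated.
similarity = np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
print(f"cosine similarity: {similarity:.3f}")
```

SemanticSplitterNodeParser relies on exactly this kind of comparison, except the vectors are produced by a real embedding model.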
2. Class Definition and Initialization
The main job of SemanticSplitterNodeParser is to split documents into semantically related nodes. Let's start with the class definition:
class SemanticSplitterNodeParser(NodeParser):
    """Semantic node parser.

    Splits a document into Nodes, with each node being a group of semantically related sentences.

    Args:
        buffer_size (int): number of sentences to group together when evaluating semantic similarity
        embed_model: (BaseEmbedding): embedding model to use
        sentence_splitter (Optional[Callable]): splits text into sentences
        include_metadata (bool): whether to include metadata in nodes
        include_prev_next_rel (bool): whether to include prev/next relationships
    """

    sentence_splitter: SentenceSplitterCallable = Field(
        default_factory=split_by_sentence_tokenizer,
        description="The text splitter to use when splitting documents.",
        exclude=True,
    )
    embed_model: SerializeAsAny[BaseEmbedding] = Field(
        description="The embedding model to use to for semantic comparison",
    )
    buffer_size: int = Field(
        default=1,
        description=(
            "The number of sentences to group together when evaluating semantic similarity. "
            "Set to 1 to consider each sentence individually. "
            "Set to >1 to group sentences together."
        ),
    )
    breakpoint_percentile_threshold: int = Field(
        default=95,
        description=(
            "The percentile of cosine dissimilarity that must be exceeded between a "
            "group of sentences and the next to form a node. The smaller this "
            "number is, the more nodes will be generated"
        ),
    )
2.1 Parameter Explanation
- `buffer_size`: how many sentences to group together when evaluating semantic similarity. The default is 1, meaning each sentence is evaluated on its own.
- `embed_model`: the embedding model used for semantic comparison.
- `sentence_splitter`: the function that splits text into sentences.
- `breakpoint_percentile_threshold`: the percentile threshold of cosine dissimilarity at which a new node starts. Smaller values produce more nodes (a brief construction sketch follows this list).
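To illustrate how these parameters fit together, here is a hedged construction sketch. It assumes the llama-index core and OpenAI embedding packages are installed and an OpenAI API key is configured; any other BaseEmbedding implementation would work just as well.

```python
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding  # requires llama-index-embeddings-openai

# buffer_size=1 evaluates each sentence on its own;
# lowering breakpoint_percentile_threshold would produce more, smaller nodes.
splitter = SemanticSplitterNodeParser(
    embed_model=OpenAIEmbedding(),
    buffer_size=1,
    breakpoint_percentile_threshold=95,
)
```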
3. Creating an Instance from Defaults
The `from_defaults` method lets us create a SemanticSplitterNodeParser instance from default values:
@classmethod
def from_defaults(
    cls,
    embed_model: Optional[BaseEmbedding] = None,
    breakpoint_percentile_threshold: Optional[int] = 95,
    buffer_size: Optional[int] = 1,
    sentence_splitter: Optional[Callable[[str], List[str]]] = None,
    original_text_metadata_key: str = DEFAULT_OG_TEXT_METADATA_KEY,
    include_metadata: bool = True,
    include_prev_next_rel: bool = True,
    callback_manager: Optional[CallbackManager] = None,
    id_func: Optional[Callable[[int, Document], str]] = None,
) -> "SemanticSplitterNodeParser":
    callback_manager = callback_manager or CallbackManager([])

    sentence_splitter = sentence_splitter or split_by_sentence_tokenizer()
    if embed_model is None:
        try:
            from llama_index.embeddings.openai import (
                OpenAIEmbedding,
            )  # pants: no-infer-dep

            embed_model = embed_model or OpenAIEmbedding()
        except ImportError:
            raise ImportError(
                "`llama-index-embeddings-openai` package not found, "
                "please run `pip install llama-index-embeddings-openai`"
            )

    id_func = id_func or default_id_func

    return cls(
        embed_model=embed_model,
        breakpoint_percentile_threshold=breakpoint_percentile_threshold,
        buffer_size=buffer_size,
        sentence_splitter=sentence_splitter,
        original_text_metadata_key=original_text_metadata_key,
        include_metadata=include_metadata,
        include_prev_next_rel=include_prev_next_rel,
        callback_manager=callback_manager,
        id_func=id_func,
    )
3.1 Code Explanation
- `embed_model`: if no embedding model is supplied, `OpenAIEmbedding` is used by default.
- `sentence_splitter`: if no sentence splitter is supplied, `split_by_sentence_tokenizer` is used by default.
- `callback_manager`: manages callbacks (a usage sketch follows this list).
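A usage sketch for `from_defaults` follows. Swapping in a local HuggingFace embedding is my own assumption here (it needs the llama-index-embeddings-huggingface package and a model name of your choosing); if you pass no embed_model at all, the method falls back to OpenAIEmbedding, as the code above shows.

```python
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding  # assumed local alternative to OpenAI

# Override only the embedding model and the threshold; everything else keeps its default.
splitter = SemanticSplitterNodeParser.from_defaults(
    embed_model=HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5"),
    breakpoint_percentile_threshold=90,  # slightly lower than the default 95, so more nodes
)
```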
4. The Core Method: build_semantic_nodes_from_documents
`build_semantic_nodes_from_documents` is the heart of SemanticSplitterNodeParser; it splits documents into semantic nodes:
def build_semantic_nodes_from_documents(
    self,
    documents: Sequence[Document],
    show_progress: bool = False,
) -> List[BaseNode]:
    """Build window nodes from documents."""
    all_nodes: List[BaseNode] = []
    for doc in documents:
        text = doc.text
        text_splits = self.sentence_splitter(text)

        sentences = self._build_sentence_groups(text_splits)

        combined_sentence_embeddings = self.embed_model.get_text_embedding_batch(
            [s["combined_sentence"] for s in sentences],
            show_progress=show_progress,
        )

        for i, embedding in enumerate(combined_sentence_embeddings):
            sentences[i]["combined_sentence_embedding"] = embedding

        distances = self._calculate_distances_between_sentence_groups(sentences)

        chunks = self._build_node_chunks(sentences, distances)

        nodes = build_nodes_from_splits(
            chunks,
            doc,
            id_func=self.id_func,
        )

        all_nodes.extend(nodes)

    return all_nodes
4.1 Code Explanation
- `text_splits`: the document text split into sentences.
- `sentences`: the sentences grouped into buffer windows, each of which gets an embedding vector.
- `distances`: the semantic distances between adjacent sentence groups.
- `chunks`: the sentence groups merged into chunks based on those distances.
- `nodes`: the chunks converted into `BaseNode` objects (an end-to-end sketch follows this list).
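Putting the pieces together, a minimal end-to-end sketch might look like this. It assumes llama-index-embeddings-openai is installed and an OpenAI API key is set; the sample text is invented for illustration.

```python
from llama_index.core import Document
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding

splitter = SemanticSplitterNodeParser.from_defaults(embed_model=OpenAIEmbedding())

docs = [
    Document(
        text=(
            "The stock market fell sharply today. Investors are worried about inflation. "
            "Meanwhile, my cat learned to open the fridge. She now steals cheese every night."
        )
    ),
]

# Each resulting node should be a group of semantically related sentences,
# so the market sentences and the cat sentences tend to land in different nodes.
nodes = splitter.build_semantic_nodes_from_documents(docs, show_progress=True)
for node in nodes:
    print("---", node.get_content())
```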
5. Helper Methods
5.1 _build_sentence_groups
def _build_sentence_groups(
    self, text_splits: List[str]
) -> List[SentenceCombination]:
    sentences: List[SentenceCombination] = [
        {
            "sentence": x,
            "index": i,
            "combined_sentence": "",
            "combined_sentence_embedding": [],
        }
        for i, x in enumerate(text_splits)
    ]

    # Group sentences and calculate embeddings for sentence groups
    for i in range(len(sentences)):
        combined_sentence = ""

        for j in range(i - self.buffer_size, i):
            if j >= 0:
                combined_sentence += sentences[j]["sentence"]

        combined_sentence += sentences[i]["sentence"]

        for j in range(i + 1, i + 1 + self.buffer_size):
            if j < len(sentences):
                combined_sentence += sentences[j]["sentence"]

        sentences[i]["combined_sentence"] = combined_sentence

    return sentences
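To see what the buffer window produces without touching the class, here is a standalone re-implementation of the same grouping loop (not library code), applied to four toy sentences:

```python
def build_windows(sentences, buffer_size=1):
    """For each sentence, concatenate it with up to buffer_size neighbors on each side."""
    windows = []
    for i in range(len(sentences)):
        start = max(0, i - buffer_size)
        end = min(len(sentences), i + 1 + buffer_size)
        windows.append("".join(sentences[start:end]))
    return windows

print(build_windows(["A. ", "B. ", "C. ", "D. "], buffer_size=1))
# ['A. B. ', 'A. B. C. ', 'B. C. D. ', 'C. D. ']
```

With buffer_size=1, each sentence is embedded together with its immediate neighbors, which smooths out noise from very short sentences.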
5.2 _calculate_distances_between_sentence_groups
def _calculate_distances_between_sentence_groups(
    self, sentences: List[SentenceCombination]
) -> List[float]:
    distances = []
    for i in range(len(sentences) - 1):
        embedding_current = sentences[i]["combined_sentence_embedding"]
        embedding_next = sentences[i + 1]["combined_sentence_embedding"]

        similarity = self.embed_model.similarity(embedding_current, embedding_next)

        distance = 1 - similarity

        distances.append(distance)

    return distances
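The distance is simply one minus the cosine similarity between adjacent window embeddings. A standalone numpy sketch of the same computation, using made-up vectors in place of real embeddings:

```python
import numpy as np

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Made-up embeddings for three consecutive sentence windows.
embeddings = [np.array([1.0, 0.0]), np.array([0.9, 0.1]), np.array([0.0, 1.0])]

# Distance between each window and the next; a sudden jump signals a topic change.
distances = [1 - cosine(embeddings[i], embeddings[i + 1]) for i in range(len(embeddings) - 1)]
print([round(d, 3) for d in distances])  # the second distance is far larger than the first
```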
5.3 _build_node_chunks
def _build_node_chunks(
    self, sentences: List[SentenceCombination], distances: List[float]
) -> List[str]:
    chunks = []
    if len(distances) > 0:
        breakpoint_distance_threshold = np.percentile(
            distances, self.breakpoint_percentile_threshold
        )

        indices_above_threshold = [
            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
        ]

        # Chunk sentences into semantic groups based on percentile breakpoints
        start_index = 0

        for index in indices_above_threshold:
            group = sentences[start_index : index + 1]
            combined_text = "".join([d["sentence"] for d in group])
            chunks.append(combined_text)

            start_index = index + 1

        if start_index < len(sentences):
            combined_text = "".join(
                [d["sentence"] for d in sentences[start_index:]]
            )
            chunks.append(combined_text)
    else:
        # If, for some reason we didn't get any distances (i.e. very, very small documents) just
        # treat the whole document as a single node
        chunks = [" ".join([s["sentence"] for s in sentences])]

    return chunks
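To see how breakpoint_percentile_threshold turns distances into breakpoints, here is a small numeric sketch with an invented distance list; only distances above the chosen percentile become chunk boundaries:

```python
import numpy as np

distances = [0.05, 0.04, 0.80, 0.06, 0.03, 0.75, 0.05]

# A lower percentile than the default 95 is used here just to force two breakpoints.
threshold = np.percentile(distances, 70)
breakpoints = [i for i, d in enumerate(distances) if d > threshold]
print(f"threshold={threshold:.2f}, breakpoints at indices {breakpoints}")
# With breakpoints at indices 2 and 5, sentence groups 0-2 form one chunk,
# 3-5 the next, and the remainder becomes a final chunk.
```

Raising the percentile toward 100 means fewer distances exceed the threshold, so you get fewer, larger nodes; lowering it does the opposite.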
6. Summary
SemanticSplitterNodeParser is a powerful tool that splits documents into multiple nodes based on semantic similarity. By understanding how it works, you can apply it more effectively to complex NLP tasks. I hope this article helps you understand the implementation in depth and put it to work in real projects.
If you have any questions or need further help, feel free to leave a comment!