1.概述
https://github.com/gusye1234/nano-graphrag
😭 GraphRAG很强大,但官方的实现阅读或修改起来非常困难。
😊 本项目提供了一个更小、更快、更简洁的 GraphRAG,同时保留了核心功能。
以下是该项目的详细代码注释,作为学习记录和后续修改代码的参考。
2.分模块注释以及分析
为了节省时间只介绍核心模块,即/nano_graphrag中的代码。
2.1 prompt.py
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}
PROMPTS["claim_extraction"],PROMPTS["community_report"],PROMPTS["entity_extraction"],PROMPTS["summarize_entity_descriptions"],
PROMPTS["entiti_continue_extraction"],PROMPTS["entiti_if_loop_extraction"],
PROMPTS["DEFAULT_ENTITY_TYPES"],PROMPTS["DEFAULT_TUPLE_DELIMITER"],PROMPTS["DEFAULT_RECORD_DELIMITER"],PROMPTS["DEFAULT_COMPLETION_DELIMITER"]
PROMPTS["local_rag_response"],PROMPTS["global_reduce_rag_response"],
PROMPTS["fail_response"],PROMPTS["process_tickers"]
使用的prompt是与官方范例相同的内容,这里就不再赘述了。
2.2 _llm.py
函数列表
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete
看着很复杂其实相当于只有两个函数,openai_embedding是调用embedding模型。
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
"""
使用OpenAI的API完成文本生成,如果缓存中已存在相同请求则返回缓存结果。
参数:
- model: 使用的OpenAI模型名称。
- prompt: 用户的输入提示文本。
- system_prompt: (可选)系统提示信息,用以引导模型的响应风格或内容。
- history_messages: (可选)历史对话消息列表,用于聊天上下文。
- **kwargs: 其他额外参数,如缓存存储对象。
返回:
- str: API响应中模型生成的文本内容。
"""
openai_async_client = AsyncOpenAI()
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 如果缓存对象存在,计算当前请求的哈希值,尝试从缓存中获取结果
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
# 如果有缓存对象,将响应结果存入缓存
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
await hashing_kv.index_done_callback()
return response.choices[0].message.content
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
其他四个都源于openai_complete_if_cache,顾名思义,调用LLM回答问题,如果有缓存则优先调用缓存中的结果,避免资源浪费。
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
openai_async_client = AsyncOpenAI()
response = await openai_async_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
记录缓存使用的格式是BaseKVStorage在base里定义。
2.3_utils.py
函数列表
logger
日志记录器,几乎每个模块都有调用
logger = logging.getLogger("nano-graphrag")
convert_response_to_json
将响应字符串转换为JSON格式的数据。通过系统变量convert_response_to_json_func: callable = convert_response_to_json来调用。
# 通过正则化匹配,从字符串中提取出JSON字符串
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
if maybe_json_str is not None:
return maybe_json_str.group(0)
else:
return None
# 将响应字符串转换为JSON格式的数据。
def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
try:
data = json.loads(json_str)
return data
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None
truncate_list_by_token_size
根据token大小截断列表数据。
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
"""
根据token大小截断列表数据。
该函数的目的是确保列表中数据的总token数不超过指定的最大token大小。
当数据的总token数超过最大允许大小时,函数将返回截断后的列表。
参数:
- list_data: list, 需要截断的列表,其中每个元素为一个数据项。
- key: callable, 用于从列表数据项中提取用于计算token大小的字符串的函数。
- max_token_size: int, 允许的最大token大小,用于决定列表数据的截断点。
返回:
- 截断后的列表。如果max_token_size小于等于0,返回空列表。
注意:
- 该函数使用tiktoken对字符串进行编码并计算token数量,请确保在使用前已安装tiktoken库。
- 截断操作基于累计token数量首次超过max_token_size发生的索引位置。
"""
if max_token_size <= 0:
return []
tokens = 0
for i, data in enumerate(list_data):
tokens += len(encode_string_by_tiktoken(key(data)))
if tokens > max_token_size:
return list_data[:i]
return list_data
compute_mdhash_id,compute_args_hash
计算哈希值的辅助函数,第二个在_llm.py中已经用过了。
write_json,load_json
用来读写json对象的辅助函数
pack_user_ass_to_openai_messages
将用户和助手的对话打包为OpenAI消息格式。
接受一系列字符串参数,成对地将它们包装成交替的用户和助手角色的消息。
这对于将对话历史记录转换为可供OpenAI的API处理的格式特别有用。在_op.py里就是调用给history的。
def pack_user_ass_to_openai_messages(*args: str):
"""
将用户和助手的对话打包为OpenAI消息格式。
该函数接受一系列字符串参数,成对地将它们包装成交替的用户和助手角色的消息。
这对于将对话历史记录转换为可供OpenAI的API处理的格式特别有用。在_op.py里就是调用给history的。
参数:
*args (str): 一个或多个字符串参数,表示用户和助手之间的对话交替发言。
返回:
list: 一个字典列表,每个字典包含两个键值对:
- 'role': 表示消息发送者的角色,根据参数序列中的位置交替为'user'或'assistant'。
- 'content': 发送者发送的消息内容,来自输入参数序列中的对应位置。
"""
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
is_float_regex
判断是否是浮点数
split_string_by_multi_markers
通过多个标记(markers)分割字符串
list_of_list_to_csv
将多维列表转换为CSV格式,用在社区结构部分。具体是_pack_single_community_describe函数。
clean_str
清理字符串,在_op.py的_handle_single_entity_extraction函数中用到。
class EmbeddingFunc
定义一个用于嵌入的函数类。
limit_async_func_call
为异步函数添加最大并发调用次数的限制。
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
"""
为异步函数添加最大并发调用次数的限制。
参数:
- max_size: int, 允许的最大并发调用次数。
- waitting_time: float, 当达到最大并发调用次数时,每次检查间隔的时间(秒),默认为0.0001秒。
返回:
- 返回一个装饰器函数,用于包装需要限制并发调用的异步函数。
"""
def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
__current_size = 0
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
# 如果当前调用数量达到最大值,等待一段时间后再检查
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
result = await func(*args, **kwargs)
__current_size -= 1
return result
return wait_func
return final_decro
wrap_embedding_func_with_attrs
用属性包装一个函数。
该函数返回一个装饰器,可用于为一个函数添加额外的属性参数。这些属性
被用于在函数执行时提供或记录额外的信息,比如函数的来源、类型等。
这个在_llm.py中调用,是给embedding函数添加了embedding_dim=1536, max_token_size=8192属性
2.4 base.py
QueryParam
查询参数类,用于定义查询时的各种配置选项。
class QueryParam:
mode: Literal["local", "global", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
level: int = 2
top_k: int = 20
# naive search
naive_max_token_for_text_unit = 12000
# local search
local_max_token_for_text_unit: int = 4000 # 12000 * 0.33
local_max_token_for_local_context: int = 4800 # 12000 * 0.4
local_max_token_for_community_report: int = 3200 # 12000 * 0.27
local_community_single_one: bool = False
# global search
global_min_community_rating: float = 0
global_max_consider_community: float = 512
global_max_token_for_community_report: int = 16384
global_special_community_map_llm_kwargs: dict = field(
default_factory=lambda: {"response_format": {"type": "json_object"}}
)
TextChunkSchema,SingleCommunitySchema
文本块,社区存储结构
class CommunitySchema
添加了字符串格式以及json格式报告的社区class,在_storage.py中有更详细的定义
class StorageNameSpace
存储命名空间类,用于管理存储操作。有namespace和global_config两个属性。是后面几个类的基础父类。
class BaseVectorStorage
基础向量存储类,继承自 StorageNameSpace。补充了embedding_func和meta_fields属性。用来定义各种向量存储,比如entities_vdb: BaseVectorStorage。
方法:
query: 查询方法,具体函数在_storage.py中,下同。
upsert: 插入或更新方法。
class BaseKVStorage
基础键值存储类,继承自 StorageNameSpace。用来定义各种键值存储,比如
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema]。
方法:
all_keys: 获取所有键。
get_by_id: 根据 ID 获取数据。
get_by_ids: 根据多个 ID 获取数据。
filter_keys: 筛选出不存在的键。
upsert: 插入或更新数据。
drop: 删除整个存储空间。
class BaseGraphStorage
基础图存储类,继承自 StorageNameSpace。用来定义各种图结构的存储,比知识图谱knwoledge_graph_inst: BaseGraphStorage
方法:
has_node: 判断节点是否存在。
has_edge: 判断边是否存在。
node_degree: 获取节点度数。
edge_degree: 获取边的度数。
get_node: 获取节点信息。
get_edge: 获取边信息。
get_node_edges: 获取节点的所有边。
upsert_node: 插入或更新节点。
upsert_edge: 插入或更新边。
clustering: 进行图聚类。
community_schema: 获取社区结构。
embed_nodes: 对节点进行嵌入。
2.5 _storage.py
里面存储了大量数据操作的函数实现,并且定义了知识图谱存储方式的类别,所以代码很长。
class JsonKVStorage
继承自BaseKVStorage,基本相同,并且包含方法的具体实现代码。
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
# 初始化时,根据全局配置确定工作目录,以便获取文件完整路径
working_dir = self.global_config["working_dir"]
# 根据命名空间生成特定的 JSON 文件名,用于存储键值数据
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
# 加载存储的数据,如果文件不存在或为空,则初始化为空字典
self._data = load_json(self._file_name) or {}
# 打印日志,显示加载的数据条数
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
# 获取所有的键列表
async def all_keys(self) -> list[str]:
return list(self._data.keys())
# 索引操作完成后,将当前数据写入 JSON 文件
async def index_done_callback(self):
write_json(self._data, self._file_name)
# 通过 ID 获取数据
async def get_by_id(self, id):
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
"""
根据ID列表获取数据项。
参数:
ids (list): 需要获取数据的ID列表。
fields (list, 可选): 限制返回数据中的字段。如果未提供,默认为None,将返回完整数据项。
返回:
list: 包含按指定ID列表顺序排列的数据项的列表。如果某些ID未找到数据项,则相应位置为None。
"""
if fields is None:
return [self._data.get(id, None) for id in ids]
return [
(
# 如果数据项存在,并且ID在_data字典中,则构建一个仅包含fields中字段的新字典
{k: v for k, v in self._data[id].items() if k in fields}
# 检查_id是否在数据集中,以避免KeyError
if self._data.get(id, None)
else None
)
for id in ids
]
# 过滤出不在数据存储中的键列表
async def filter_keys(self, data: list[str]) -> set[str]:
return set([s for s in data if s not in self._data])
# 插入或更新数据
async def upsert(self, data: dict[str, dict]):
self._data.update(data)
# 清空当前存储数据
async def drop(self):
self._data = {}
class NanoVectorDBStorage
继承自BaseVectorDBStorage,还是用了类别NanoVectorDB。
class NanoVectorDBStorage(BaseVectorStorage):
# 余弦相似度阈值,决定返回的结果质量
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
# 初始化向量数据库存储文件和嵌入配置
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
# 初始化向量数据库客户端(NanoVectorDB),并设置嵌入维度
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
# 从全局配置中获取查询的相似度阈值,或使用默认值
self.cosine_better_than_threshold = self.global_config.get(
"query_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]):
"""
插入或更新向量数据。
该方法用于将字典形式的数据插入或更新到向量数据库中。数据首先被转换成适合插入的格式,
然后分批处理,以避免一次性插入过多数据导致的性能问题。之后,使用异步方式计算各批次数据的嵌入向量,
并将这些向量附加到数据条目中,最后调用客户端的插入或更新方法完成操作。
参数:
data: dict[str, dict] - 一个字典,键是数据的唯一标识,值是包含实际数据内容的字典。
返回:
插入或更新操作的结果。
"""
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
# 将数据转换为适合插入的列表,并提取内容
list_data = [
{
"__id__": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
# 将数据按批次进行处理
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
# 异步计算各批次数据的嵌入向量
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
# 将所有批次的嵌入向量合并为一个大数组
embeddings = np.concatenate(embeddings_list)
# 将计算得到的嵌入向量附加到每个数据条目中
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# 调用客户端的插入或更新方法,完成数据的插入或更新
results = self._client.upsert(datas=list_data)
return results
async def query(self, query: str, top_k=5):
"""
根据提供的查询字符串获取最相关的文档。
此异步方法使用预训练的embedding函数将查询转换为嵌入表示,
然后在嵌入索引中搜索与查询最相似的文档。
参数:
- query: str,用户查询的字符串。
- top_k: int,返回最相关的文档数量,默认为5。
返回:
- 一个列表,包含最相关的文档及其与查询的相似度距离。
"""
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
# 整理结果,添加文档id和距离信息
results = [
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
return results
async def index_done_callback(self):
self._client.save()
HNSWVectorStorage(另一种向量存储方式,这里先略过)
class NetworkXStorage
继承自BaseGraphStorage,存储图结构的数据。
可以先参考这个链接学习一些图(graph)知识。
图的一些基本知识(连通图、连通分量、最小生成树等知识的基本介绍)-CSDN博客
这里社区发现算法还是调用的外部库,和微软开源的GraphRAG一致。
class NetworkXStorage(BaseGraphStorage):
# 加载并返回一个NetworkX图,存储格式为graphml。
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
# 将NetworkX图写入graphml文件。
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
nx.write_graphml(graph, file_name)
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
"""
"""返回图的最大连通分量,并以稳定的方式排序节点和边。
参数:
graph (nx.Graph): 输入的 NetworkX 图。
返回:
nx.Graph: 输入图的最大连通分量,以稳定方式排序。
"""
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
"""
"""
确保无向图以相同的关系读取时始终相同。
参数:
graph (nx.Graph): 输入的网络图。
返回:
nx.Graph: 经过稳定处理的网络图。
"""
# 根据输入图的类型初始化一个新的图实例
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
# 对节点进行排序,以确保节点的添加顺序一致
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
# 向新图中添加排序后的节点
fixed_graph.add_nodes_from(sorted_nodes)
# 将边数据存储到列表中以便后续处理
edges = list(graph.edges(data=True))
# 如果图不是有向图,则对边进行排序,以确保边的顺序一致
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
# 定义获取边的键的函数,用于后续边的排序
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
# 对边进行排序
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
# 向新图中添加排序后的边
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
"""
初始化函数,用于加载图数据并初始化相关属性。
该函数首先根据全局配置中的工作目录和实例的命名空间来确定graphml文件的路径。
然后尝试从该路径加载已存在的图数据。如果图数据存在,则使用NetworkXStorage加载,
并记录日志信息包括图的节点数和边数。如果图数据不存在,则初始化一个新的无向图。
最后,初始化两个算法字典,分别用于图的聚类算法和节点嵌入算法。
"""
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
self._clustering_algorithms = {
"leiden": self._leiden_clustering,
}
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
# 将当前存储的图写入到GraphML文件中
async def index_done_callback(self):
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
"""
异步检查图中是否存在指定节点。下面几个函数类似,就不做标注了,都是调用NetWorkX的函数。
该方法主要用于确定图结构中是否包含特定的节点。它通过调用底层图对象的has_node方法,
以高效的方式查询节点是否存在。
参数:
node_id (str): 要检查的节点的唯一标识符。
返回:
bool: 如果图中存在该节点,则返回True,否则返回False。
"""
return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
return self._graph.nodes.get(node_id)
# 获取指定节点的度数。
async def node_degree(self, node_id: str) -> int:
# [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
# 计算两个节点的度数之和
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
# 根据指定的算法执行聚类操作。
async def clustering(self, algorithm: str):
if algorithm not in self._clustering_algorithms:
raise ValueError(f"Clustering algorithm {algorithm} not supported")
await self._clustering_algorithms[algorithm]()
async def community_schema(self) -> dict[str, SingleCommunitySchema]:
"""
生成社区架构的异步方法。
该方法计算并组织图中所有节点的社区结构信息。它通过分析节点的集群信息,
构建不同层级的社区,并计算社区的各种属性,如层次、节点数、边数等。
Returns:
一个字典,键为社区键,值为SingleCommunitySchema对象,包含每个社区的详细信息。
"""
# 初始化结果字典,为每个社区准备默认属性
results = defaultdict(
lambda: dict(
level=None,
title=None,
edges=set(),
nodes=set(),
chunk_ids=set(),
occurrence=0.0,
sub_communities=[],
)
)
max_num_ids = 0
# 用于存储不同层级的社区
levels = defaultdict(set)
# 遍历图中的所有节点,收集社区信息
for node_id, node_data in self._graph.nodes(data=True):
# 如果节点没有集群信息,则跳过
if "clusters" not in node_data:
continue
clusters = json.loads(node_data["clusters"])
this_node_edges = self._graph.edges(node_id)
# 遍历节点的所有集群信息
for cluster in clusters:
# 提取并更新社区的层级信息
level = cluster["level"]
cluster_key = str(cluster["cluster"])
levels[level].add(cluster_key)
results[cluster_key]["level"] = level
results[cluster_key]["title"] = f"Cluster {cluster_key}"
results[cluster_key]["nodes"].add(node_id)
results[cluster_key]["edges"].update(
[tuple(sorted(e)) for e in this_node_edges]
)
results[cluster_key]["chunk_ids"].update(
node_data["source_id"].split(GRAPH_FIELD_SEP)
)
# 计算最大的chunk_ids数量,用于后续计算出现率
max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))
# 对层级进行排序,以便后续处理
ordered_levels = sorted(levels.keys())
# 构建子社区关系
for i, curr_level in enumerate(ordered_levels[:-1]):
next_level = ordered_levels[i + 1]
this_level_comms = levels[curr_level]
next_level_comms = levels[next_level]
# compute the sub-communities by nodes intersection
# 通过节点交集计算子社区关系
for comm in this_level_comms:
results[comm]["sub_communities"] = [
c
for c in next_level_comms
if results[c]["nodes"].issubset(results[comm]["nodes"])
]
# 处理并标准化结果字典中的数据
for k, v in results.items():
v["edges"] = list(v["edges"])
v["edges"] = [list(e) for e in v["edges"]]
v["nodes"] = list(v["nodes"])
v["chunk_ids"] = list(v["chunk_ids"])
v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
return dict(results)
def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
"""
将聚类数据分配到子图
该方法将给定的聚类数据分配到图中的相应节点。每个节点的聚类信息
以JSON格式存储在节点的'clusters'属性中。
参数:
cluster_data: 字典,包含节点ID和对应的聚类列表。每个聚类是一个字典,
包含节点在不同聚类中的信息。
"""
# 遍历聚类数据字典中的每个节点及其聚类信息
for node_id, clusters in cluster_data.items():
# 将节点的聚类信息转换为JSON格式,并存储在图节点的'clusters'属性中
self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
async def _leiden_clustering(self):
"""
基于Leiden算法的图聚类异步方法。
本方法使用Leiden算法对图进行聚类,以发现图中的社区结构。
它从当前图的稳定最大连通组件开始,根据全局配置中的参数进行聚类,
并将聚类结果转换为子图。
"""
from graspologic.partition import hierarchical_leiden
# 获取图的稳定最大连通组件
graph = NetworkXStorage.stable_largest_connected_component(self._graph)
# 应用Leiden算法进行层次聚类
community_mapping = hierarchical_leiden(
graph,
max_cluster_size=self.global_config["max_graph_cluster_size"],
random_seed=self.global_config["graph_cluster_seed"],
)
# 准备节点社区字典
node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
# 准备用于跟踪各级别社区数量的字典
__levels = defaultdict(set)
# 遍历社区映射,构建节点社区字典和级别跟踪字典
for partition in community_mapping:
level_key = partition.level
cluster_id = partition.cluster
node_communities[partition.node].append(
{"level": level_key, "cluster": cluster_id}
)
__levels[level_key].add(cluster_id)
# 转换节点社区字典和级别跟踪字典为最终形式
node_communities = dict(node_communities)
__levels = {k: len(v) for k, v in __levels.items()}
# 记录每级社区的数量信息
logger.info(f"Each level has communities: {dict(__levels)}")
# 将聚类数据转换为子图
self._cluster_data_to_subgraphs(node_communities)
# 根据指定的算法进行节点嵌入。
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def _node2vec_embed(self):
"""
异步方法,用于通过node2vec算法嵌入图结构数据。
该方法使用graspologic库的node2vec_embed函数,根据内部图结构和配置参数进行图嵌入。
它首先调用嵌入函数,然后提取嵌入结果中节点的ID,并返回嵌入向量和节点ID列表。
"""
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
2.6 _op.py
一看就知道,最重要的几个函数基本都封装在这里,所以代码量也来到了作者所说的800行级别(实际是1300行)。
┓( ´∀` )┏
函数列表
chunking_by_token_size
用来进行分块。
def chunking_by_token_size(
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
"""
根据token大小对文本进行分块。
该函数用于将给定的文本内容按照指定的token大小限制进行分块,同时保证相邻块之间有重叠。
主要用于处理大文本,使其能够适应如OpenAI的GPT系列模型的输入限制。
参数:
- content: str, 待分块的文本内容。
- overlap_token_size: int, 默认128. 相邻文本块之间的重叠token数。
- max_token_size: int, 默认1024. 每个文本块的最大token数。
- tiktoken_model: str, 默认"gpt-4o". 用于token化和去token化的tiktoken模型。
返回:
- List[Dict[str, Any]], 包含每个文本块的tokens数量、文本内容和块顺序索引的列表。
"""
# 使用指定的tiktoken模型对文本进行token化
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
# 初始化存储分块结果的列表
results = []
# 遍历tokens,根据max_token_size和overlap_token_size进行分块
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
):
# 根据当前分块的起始位置和最大token数限制,获取分块的tokens
chunk_content = decode_tokens_by_tiktoken(
tokens[start : start + max_token_size], model_name=tiktoken_model
)
# 将当前分块的tokens数量、文本内容和块顺序索引添加到结果列表中
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
"content": chunk_content.strip(),
"chunk_order_index": index,
}
)
return results
extract_entities
内嵌了实体合并的代码,调用了很多_op.py中的函数,这里就不逐个赘述了。
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
global_config: dict,
) -> Union[BaseGraphStorage, None]:
"""
异步函数extract_entities从文本块中提取实体并更新知识图谱。
参数:
chunks (dict[str, TextChunkSchema]): 文本块字典,键为文本块标识,值为包含文本块内容的TextChunkSchema对象。
knwoledge_graph_inst (BaseGraphStorage): 知识图谱实例,用于存储提取的实体和关系。
entity_vdb (BaseVectorStorage): 实体向量数据库实例,用于存储实体的向量表示。
global_config (dict): 全局配置字典,包含模型函数、最大迭代次数等参数。
返回:
Union[BaseGraphStorage, None]: 更新后的知识图谱实例,如果没有提取到任何实体,则返回None。
"""
# 从全局配置中提取模型函数和实体提取最大迭代次数
use_llm_func: callable = global_config["best_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
# 将文本块排序,以便按顺序处理
ordered_chunks = list(chunks.items())
# 准备实体提取的提示模板和上下文基础信息
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
)
continue_prompt = PROMPTS["entiti_continue_extraction"]
if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
# 初始化计数器
already_processed = 0
already_entities = 0
already_relations = 0
async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
"""
异步函数_process_single_content处理单个文本块的内容,提取实体并更新知识图谱。
参数:
chunk_key_dp (tuple[str, TextChunkSchema]): 文本块的键值对,包含文本块标识和内容。
返回:
dict: 包含从文本块中提取的可能的节点和边的字典。
"""
# 初始化非局部变量,用于跟踪处理进度和统计信息
nonlocal already_processed, already_entities, already_relations
# 解析文本块的键和数据
chunk_key = chunk_key_dp[0]
chunk_dp = chunk_key_dp[1]
content = chunk_dp["content"]
# 构建提示信息并调用模型函数提取实体
hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
final_result = await use_llm_func(hint_prompt)
# 构建对话历史并进行多次迭代以提取更多实体
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
for now_glean_index in range(entity_extract_max_gleaning):
glean_result = await use_llm_func(continue_prompt, history_messages=history)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
final_result += glean_result
# 检查是否继续迭代
if now_glean_index == entity_extract_max_gleaning - 1:
break
if_loop_result: str = await use_llm_func(
if_loop_prompt, history_messages=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
# 解析结果并更新知识图谱
records = split_string_by_multi_markers(
final_result,
[context_base["record_delimiter"], context_base["completion_delimiter"]],
)
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for record in records:
record = re.search(r"\((.*)\)", record)
if record is None:
continue
record = record.group(1)
record_attributes = split_string_by_multi_markers(
record, [context_base["tuple_delimiter"]]
)
# 处理实体提取
if_entities = await _handle_single_entity_extraction(
record_attributes, chunk_key
)
if if_entities is not None:
maybe_nodes[if_entities["entity_name"]].append(if_entities)
continue
# 处理关系提取
if_relation = await _handle_single_relationship_extraction(
record_attributes, chunk_key
)
if if_relation is not None:
maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
if_relation
)
# 更新处理进度和统计信息
already_processed += 1
already_entities += len(maybe_nodes)
already_relations += len(maybe_edges)
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return dict(maybe_nodes), dict(maybe_edges)
# 并发处理所有文本块
# use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
results = await asyncio.gather(
*[_process_single_content(c) for c in ordered_chunks]
)
print() # clear the progress bar
maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list)
for m_nodes, m_edges in results:
for k, v in m_nodes.items():
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
# it's undirected graph
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
await asyncio.gather(
*[
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
if not len(all_entities_data):
logger.warning("Didn't extract any entities, maybe your LLM is not working")
return None
if entity_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await entity_vdb.upsert(data_for_vdb)
return knwoledge_graph_inst
generate_community_report
生成社区报告,同样调用了很多其他函数。
async def generate_community_report(
community_report_kv: BaseKVStorage[CommunitySchema],
knwoledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
"""
异步生成社区报告函数
该函数用于异步生成社区报告,根据提供的知识图谱实例和全局配置,
从社区模式中提取数据,并利用语言模型生成社区报告。
参数:
- community_report_kv: 实现了社区模式存储的BaseKVStorage实例。
- knwoledge_graph_inst: 实现了知识图谱存储的BaseGraphStorage实例。
- global_config: 包含生成社区报告所需的各种配置项的字典。
返回值:
无返回值,但会打印处理进度,并将生成的社区报告存储在community_report_kv中。
"""
# 从全局配置中提取语言模型的额外参数、使用的模型函数和字符串到JSON的转换函数
llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
use_llm_func: callable = global_config["best_model_func"]
use_string_json_convert_func: callable = global_config[
"convert_response_to_json_func"
]
# 加载社区报告的提示模板
community_report_prompt = PROMPTS["community_report"]
# 从知识图谱实例中获取所有社区模式,并初始化已处理社区计数器
communities_schema = await knwoledge_graph_inst.community_schema()
community_keys, community_values = list(communities_schema.keys()), list(
communities_schema.values()
)
already_processed = 0
# 定义异步函数_form_single_community_report,用于生成单个社区的报告
async def _form_single_community_report(
community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
):
nonlocal already_processed
# 为当前社区生成描述文本
describe = await _pack_single_community_describe(
knwoledge_graph_inst,
community,
max_token_size=global_config["best_model_max_token_size"],
already_reports=already_reports,
global_config=global_config,
)
# 构建完整的prompt并使用语言模型生成响应
prompt = community_report_prompt.format(input_text=describe)
response = await use_llm_func(prompt, **llm_extra_kwargs)
# 将响应转换为JSON格式,并更新已处理社区计数器和进度打印
data = use_string_json_convert_func(response)
already_processed += 1
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
f"{now_ticks} Processed {already_processed} communities\r",
end="",
flush=True,
)
return data
# 按社区级别排序,并从高到低处理
levels = sorted(set([c["level"] for c in community_values]), reverse=True)
logger.info(f"Generating by levels: {levels}")
community_datas = {}
for level in levels:
# 筛选当前级别的社区,并并行生成报告
this_level_community_keys, this_level_community_values = zip(
*[
(k, v)
for k, v in zip(community_keys, community_values)
if v["level"] == level
]
)
this_level_communities_reports = await asyncio.gather(
*[
_form_single_community_report(c, community_datas)
for c in this_level_community_values
]
)
# 将生成的报告整合到社区数据字典中
community_datas.update(
{
k: {
"report_string": _community_report_json_to_str(r),
"report_json": r,
**v,
}
for k, r, v in zip(
this_level_community_keys,
this_level_communities_reports,
this_level_community_values,
)
}
)
print() # clear the progress bar
await community_report_kv.upsert(community_datas)
local_query
本地查询。
async def local_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
"""
根据查询和配置,执行本地查询并返回结果。
该函数首先根据查询和一系列存储实例构建本地查询上下文。如果查询参数指示只需要上下文,
则直接返回上下文字符串。如果上下文为空,则返回预定义的查询失败响应。
否则,将根据上下文和查询参数构造系统提示,并使用全局配置中指定的最佳模型函数生成响应。
参数:
query: 查询字符串。
knowledge_graph_inst: 知识图谱存储实例。
entities_vdb: 实体向量存储实例。
community_reports: 社区报告键值存储实例。
text_chunks_db: 文本块键值存储实例。
query_param: 查询参数,包含查询类型和响应类型等信息。
global_config: 全局配置字典,包含模型函数等关键配置。
返回:
根据查询和上下文生成的响应字符串。
"""
# 从全局配置中获取最佳模型函数
use_model_func = global_config["best_model_func"]
# 构建本地查询上下文
context = await _build_local_query_context(
query,
knowledge_graph_inst,
entities_vdb,
community_reports,
text_chunks_db,
query_param,
)
# 如果查询参数指示只需要上下文,则直接返回上下文
if query_param.only_need_context:
return context
# 如果上下文为空,则返回查询失败的响应
if context is None:
return PROMPTS["fail_response"]
# 根据预定义模板构造系统提示,包含上下文数据和查询参数中的响应类型
sys_prompt_temp = PROMPTS["local_rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
# 使用最佳模型函数生成查询响应
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
return response
global_query
全局查询,并没有使用map-reduce的方法,而是直接把社区摘要进行了top-k排序。
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
) -> str:
"""
异步执行全局查询。
此函数根据查询参数和配置从知识图谱、实体向量存储、社区报告及文本块数据库中获取相关信息,
并对这些信息进行处理和排序,最终生成一个综合的回答。
参数:
- query: 查询字符串
- knowledge_graph_inst: 知识图谱存储实例
- entities_vdb: 实体向量存储实例
- community_reports: 社区报告存储实例,存储类型为BaseKVStorage
- text_chunks_db: 文本块存储实例,存储类型为BaseKVStorage
- query_param: 查询参数对象,包含查询级别、最大考虑社区数等参数
- global_config: 全局配置字典,包含最佳模型函数等关键配置
返回:
- str: 查询的最终回答字符串
"""
# 获取并筛选社区schema
community_schema = await knowledge_graph_inst.community_schema()
community_schema = {
k: v for k, v in community_schema.items() if v["level"] <= query_param.level
}
if not len(community_schema):
return PROMPTS["fail_response"]
use_model_func = global_config["best_model_func"]
# 对社区schema进行排序
sorted_community_schemas = sorted(
community_schema.items(),
key=lambda x: x[1]["occurrence"],
reverse=True,
)
sorted_community_schemas = sorted_community_schemas[
: query_param.global_max_consider_community
]
community_datas = await community_reports.get_by_ids(
[k[0] for k in sorted_community_schemas]
)
community_datas = [c for c in community_datas if c is not None]
community_datas = [
c
for c in community_datas
if c["report_json"].get("rating", 0) >= query_param.global_min_community_rating
]
community_datas = sorted(
community_datas,
key=lambda x: (x["occurrence"], x["report_json"].get("rating", 0)),
reverse=True,
)
logger.info(f"Revtrieved {len(community_datas)} communities")
# 映射社区数据并聚合支持点
map_communities_points = await _map_global_communities(
query, community_datas, query_param, global_config
)
final_support_points = []
for i, mc in enumerate(map_communities_points):
for point in mc:
if "description" not in point:
continue
final_support_points.append(
{
"analyst": i,
"answer": point["description"],
"score": point.get("score", 1),
}
)
final_support_points = [p for p in final_support_points if p["score"] > 0]
if not len(final_support_points):
return PROMPTS["fail_response"]
final_support_points = sorted(
final_support_points, key=lambda x: x["score"], reverse=True
)
final_support_points = truncate_list_by_token_size(
final_support_points,
key=lambda x: x["answer"],
max_token_size=query_param.global_max_token_for_community_report,
)
points_context = []
for dp in final_support_points:
points_context.append(
f"""----Analyst {dp['analyst']}----
Importance Score: {dp['score']}
{dp['answer']}
"""
)
points_context = "\n".join(points_context)
if query_param.only_need_context:
return points_context
sys_prompt_temp = PROMPTS["global_reduce_rag_response"]
response = await use_model_func(
query,
sys_prompt_temp.format(
report_data=points_context, response_type=query_param.response_type
),
)
return response
naive_query
朴素查询,不做重点。
2.7 graphrag.py
这回终于可以看主函数了,但有了前面的铺垫,具体代码就比较简单了。不过里面包含了不少增量式的处理方法,只要理解基本思想和功能看代码也会清楚一些。目标是可以实现输入a文档后,再输入b文档补充到a文档生成的知识图谱中,也就是图扩展。以后可以考虑在此基础上做图更新。
函数列表
always_get_an_event_loop()
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
# If there is already an event loop, use it.
loop = asyncio.get_event_loop()
except RuntimeError:
# If in a sub-thread, create a new event loop.
logger.info("Creating a new event loop in a sub-thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
class GraphRAG
主要函数类别,实现模块化调用。
insert()通过always_get_an_event_loop()调用异步函数ainsert(),以实现增量添加文本文件,query()和aquery()同理。
ainsert()——分块,实体关系提取,知识图谱聚类,社区报告的函数调用
aquery()——local,global,naive
class GraphRAG:
# 表示工作目录的路径,默认是根据当前日期时间生成的目录
working_dir: str = field(
default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)
# graph mode
# 是否本地存储,默认是本地存储
enable_local: bool = True
# 是否启用朴素rag,默认不启用
enable_naive_rag: bool = False
# text chunking
chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
# 分块大小
chunk_token_size: int = 1200
# 分块重叠数量
chunk_overlap_token_size: int = 100
# tiktoken使用的模型名字,默认为gpt-4o
tiktoken_model_name: str = "gpt-4o"
# entity extraction
# 实体提取最大“拾取”次数,也就是反复提取次数,默认不反复提取
entity_extract_max_gleaning: int = 1
# 实体摘要最大token数
entity_summary_to_max_tokens: int = 500
# graph clustering
# 社区聚类算法,默认为莱顿算法
graph_cluster_algorithm: str = "leiden"
# 最大社区聚类节点数,默认为10
max_graph_cluster_size: int = 10
# 社区聚类随机种子,默认为0xDEADBEEF
graph_cluster_seed: int = 0xDEADBEEF
# node embedding
node_embedding_algorithm: str = "node2vec"
# 如果没有显式传入 node2vec_params,则会调用这个 lambda 函数,自动生成并赋值为这个默认的字典
node2vec_params: dict = field(
default_factory=lambda: {
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
"num_walks": 10,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
}
)
# community reports
# 以json格式返回社区报告
special_community_report_llm_kwargs: dict = field(
default_factory=lambda: {"response_format": {"type": "json_object"}}
)
# text embedding
# 默认使用openai的embedding
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
# 批量大小,默认为32
embedding_batch_num: int = 32
# 最大并发请求数量
embedding_func_max_async: int = 16
# 用于比较查询结果质量的阈值
query_better_than_threshold: float = 0.2
# LLM相关调用
# 需要两种类型的大语言模型(LLM),一种是高性能的,用于规划和回应;另一种是成本较低的,用于总结
using_azure_openai: bool = False
best_model_func: callable = gpt_4o_complete
best_model_max_token_size: int = 32768
best_model_max_async: int = 16
cheap_model_func: callable = gpt_4o_mini_complete
cheap_model_max_token_size: int = 32768
cheap_model_max_async: int = 16
# entity extraction
# 实体提取函数
entity_extraction_func: callable = extract_entities
# storage
# 存储类型设置
# 键值存储,json,具体定义在_storage.py
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
# 向量数据库存储,具体定义在_storage.py
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
# 为向量数据库存储类提供可选的参数字典
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
# 图数据库存储,默认为NetworkXStorage
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
# 是否启用llm缓存
enable_llm_cache: bool = True
# extension
# 用于传递额外参数的字典
addon_params: dict = field(default_factory=dict)
# 用于将 LLM 输出转换为 JSON 的函数
convert_response_to_json_func: callable = convert_response_to_json
# 在对象初始化后调用此方法,主要作用为打印配置信息和根据配置进行一些设置调整
def __post_init__(self):
# 将对象的属性以键值对的形式打印出来,用于调试和日志记录
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"GraphRAG init with param:\n\n {_print_config}\n")
# 如果配置了使用Azure OpenAI服务,则调整相关函数为Azure版本
if self.using_azure_openai:
if self.best_model_func == gpt_4o_complete:
self.best_model_func = azure_gpt_4o_complete
if self.cheap_model_func == gpt_4o_mini_complete:
self.cheap_model_func = azure_gpt_4o_mini_complete
if self.embedding_func == openai_embedding:
self.embedding_func = azure_openai_embedding
logger.info(
"Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
)
# 确保工作目录存在,如果不存在则创建
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
# 初始化存储类实例,用于存储完整文档
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs", global_config=asdict(self)
)
# 初始化存储类实例,用于存储文本块
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks", global_config=asdict(self)
)
# 根据配置初始化LLM响应缓存,如果配置为启用则创建缓存实例
self.llm_response_cache = (
self.key_string_value_json_storage_cls(
namespace="llm_response_cache", global_config=asdict(self)
)
if self.enable_llm_cache
else None
)
# 初始化存储类实例,用于存储社区报告
self.community_reports = self.key_string_value_json_storage_cls(
namespace="community_reports", global_config=asdict(self)
)
# 初始化图存储类实例,用于存储块实体关系图
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation", global_config=asdict(self)
)
# 限制embedding函数的异步调用次数
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# 根据配置初始化向量数据库存储类实例,用于存储实体
self.entities_vdb = (
self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
if self.enable_local
else None
)
# 根据配置初始化向量数据库存储类实例,用于存储块
self.chunks_vdb = (
self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
if self.enable_naive_rag
else None
)
# 限制最佳模型函数的异步调用次数,并为其配置哈希键值存储
self.best_model_func = limit_async_func_call(self.best_model_max_async)(
partial(self.best_model_func, hashing_kv=self.llm_response_cache)
)
# 限制廉价模型函数的异步调用次数,并为其配置哈希键值存储
self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
)
def insert(self, string_or_strings):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))
def query(self, query: str, param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))
async def aquery(self, query: str, param: QueryParam = QueryParam()):
"""
异步查询函数,根据指定的查询模式执行相应的查询。
该函数支持三种查询模式:本地模式 ("local")、全局模式 ("global") 和朴素模式 ("naive")。
根据提供的参数确定使用哪种查询模式,并在查询完成后执行查询结束的钩子函数。
参数:
- query: str, 要查询的字符串。
- param: QueryParam, 查询参数对象,包含查询模式和其他查询相关的参数。
返回:
- response: 从查询函数返回的响应。
异常:
- ValueError: 当尝试在不支持的模式下执行查询时抛出。
"""
if param.mode == "local" and not self.enable_local:
raise ValueError("enable_local is False, cannot query in local mode")
if param.mode == "naive" and not self.enable_naive_rag:
raise ValueError("enable_naive_rag is False, cannot query in local mode")
# 根据查询模式执行相应的查询函数
if param.mode == "local":
response = await local_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "global":
response = await global_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,
self.community_reports,
self.text_chunks,
param,
asdict(self),
)
elif param.mode == "naive":
response = await naive_query(
query,
self.chunks_vdb,
self.text_chunks,
param,
asdict(self),
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
async def ainsert(self, string_or_strings):
"""
异步插入字符串或字符串列表到文档和片段数据库中,并更新知识图谱和社区报告。
参数:
string_or_strings (str 或 List[str]): 要插入的字符串或字符串列表。
"""
try:
# 如果输入是一个字符串,将其转换为列表
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# ---------- new docs
# 将字符串或字符串列表string_or_strings中的每个元素去除首尾空白后,作为文档内容。
# 计算其MD5哈希值并添加前缀doc-作为键,内容本身作为值,生成一个新的字典new_docs。
new_docs = {
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
}
# 筛选出需要添加的新文档ID
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
# 根据筛选结果更新新文档字典
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
# 如果没有新文档需要添加,记录日志并返回
if not len(new_docs):
logger.warning(f"All docs are already in the storage")
return
# 记录插入新文档的日志
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
# ---------- chunking
inserting_chunks = {}
for doc_key, doc in new_docs.items():
# 为每个文档生成片段
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_key,
}
for dp in self.chunk_func(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
inserting_chunks.update(chunks)
# 筛选出需要添加的新片段ID
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
# 根据筛选结果更新新片段字典
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
# 如果没有新片段需要添加,记录日志并返回
if not len(inserting_chunks):
logger.warning(f"All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
# 如果启用了简单的RAG,则插入片段到相应的数据库
if self.enable_naive_rag:
logger.info("Insert chunks for naive RAG")
await self.chunks_vdb.upsert(inserting_chunks)
# 由于目前不支持增量更新社区,因此删除所有现有的社区报告
# TODO: no incremental update for communities now, so just drop all
await self.community_reports.drop()
# ---------- extract/summary entity and upsert to graph
# 提取新实体和关系,并更新到知识图谱中
logger.info("[Entity Extraction]...")
maybe_new_kg = await self.entity_extraction_func(
inserting_chunks,
knwoledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities found")
return
self.chunk_entity_relation_graph = maybe_new_kg
# ---------- update clusterings of graph
# 更新知识图谱的聚类
logger.info("[Community Report]...")
await self.chunk_entity_relation_graph.clustering(
self.graph_cluster_algorithm
)
# 生成并更新社区报告
await generate_community_report(
self.community_reports, self.chunk_entity_relation_graph, asdict(self)
)
# ---------- commit upsertings and indexing
# 提交所有更新和索引操作
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
await self._insert_done()
async def _insert_done(self):
tasks = []
for storage_inst in [
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.community_reports,
self.entities_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
]:
if storage_inst is None:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
async def _query_done(self):
"""
异步函数,用于在查询完成后,确保所有相关的存储实例的索引更新操作完成。
1. 遍历预定义的存储实例列表(当前只有一个llm_response_cache)。
2. 对于每个非空的存储实例,添加其索引更新完成的回调函数到任务列表。
3. 使用asyncio.gather等待所有添加到任务列表中的回调函数执行完成。
该函数的主要作用是确保在查询完成后,所有配置的存储实例都完成了它们的索引更新操作,
这对于保持数据一致性和完整性至关重要。
"""
tasks = []
for storage_inst in [self.llm_response_cache]:
if storage_inst is None:
continue
# 将存储实例的索引更新完成的回调函数添加到任务列表
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
3.使用测试
这里我采用本地embedding和调用kimi作为大模型来测试结果,根据官方例子编写了如下的代码。
import os
import sys
import logging
from openai import AsyncOpenAI
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
from sentence_transformers import SentenceTransformer
import numpy as np
sys.path.append("..")
logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)
API_KEY = "填你自己的"
MODEL = "moonshot-v1-32k"
URL = "https://api.moonshot.cn/v1"
QUESTION = "请概括故事的主要情节并分析故事的主旨。"
WORKING_DIR = "./workspace"
FILEPATH = "./txt/追逐雪的人.txt"
EMBED_MODEL = SentenceTransformer(
"sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
)
async def model_if_cache(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
openai_async_client = AsyncOpenAI(
api_key=API_KEY, base_url=URL
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Get the cached response if having-------------------
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(MODEL, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# -----------------------------------------------------
response = await openai_async_client.chat.completions.create(
model=MODEL, messages=messages, **kwargs
)
# Cache the response if having-------------------
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
)
# -----------------------------------------------------
return response.choices[0].message.content
def remove_if_exist(file):
if os.path.exists(file):
os.remove(file)
def query():
rag = GraphRAG(
working_dir=WORKING_DIR,
best_model_func=model_if_cache,
cheap_model_func=model_if_cache,
embedding_func=local_embedding
)
print(
rag.query(
QUESTION, param=QueryParam(mode="global")
)
)
@wrap_embedding_func_with_attrs(
embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
return EMBED_MODEL.encode(texts, normalize_embeddings=True)
def insert():
from time import time
with open(FILEPATH, encoding="utf-8-sig") as f:
FAKE_TEXT = f.read()
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
rag = GraphRAG(
working_dir=WORKING_DIR,
enable_llm_cache=True,
best_model_func=model_if_cache,
cheap_model_func=model_if_cache,
embedding_func=local_embedding
)
start = time()
rag.insert(FAKE_TEXT)
print("indexing time:", time() - start)
if __name__ == "__main__":
insert()
query()
回答结果如下
### 故事主要情节概括
故事主要围绕两个核心人物——"她"和"我"——以及他们的宠物"花卷"展开。这两个人物通过共同的活动、
深入的对话和相互关怀来加深彼此的情感联系。他们一起经历了各种活动和情感交流,包括参观博物馆、
讨论艺术作品,以及参与社区中的艺术和教育活动。宠物"花卷"在故事中扮演了情感纽带的角色,增强了
两位主角之间的联系,并且是他们共同关心和照顾的对象。
### 故事主旨分析
1. **人际关系与情感联系**:故事的主旨在于探索人与人之间的深层情感联系,以及在共同经历和相互
支持中成长和发现生活的意义。"她"和"我"之间的关系是故事的核心,他们通过亲密的旅程,强调了在面
对生活中的挑战和困难时,人与人之间的相互支持和关怀的重要性。
2. **社区互动与个人成长**:故事还展现了社区互动和个人成长的重要性。一个男孩在社区中的互动和
成长经历,涉及学校活动、艺术任务和个人关系,表明社区成员之间的互动对个人发展具有重要影响。艺
术和教育活动在促进社区凝聚力和个人发展中发挥了重要作用。
3. **艺术与美的欣赏**:故事中包含了对艺术和美的欣赏,两位主角在参观博物馆和讨论艺术作品时展
现了他们的情感智慧和共同的感性。艺术活动,特别是海报着色任务,是连接社区成员的重要活动,促进
社区参与和艺术表达。
4. **教育与权威角色**:老师在社区中具有双重角色,既是权威人物,又是教育活动的促进者,管理课
堂行为并支持艺术任务的完成。这表明教育者在塑造社区文化和促进个人成长中扮演着关键角色。
5. **导师与钦佩动态**:男孩与俱乐部领袖的关系表明社区内存在强烈的导师或钦佩动态,影响男孩在
社区中的动机和参与度。这种关系对个人成长和社区凝聚力具有积极影响。
综上所述,故事通过"她"、"我"和宠物"花卷"的亲密旅程,以及他们在社区中的互动和成长经历,探讨了
人际关系、社区互动、艺术欣赏、教育角色和导师动态等多重主题,展现了人与人之间的情感联系、个人
成长和社区凝聚力的重要性。
生成的文件均为json格式,知识图谱是graphml格式。
构建的知识图谱如下图。