nano-graphrag代码详解

1.概述

https://github.com/gusye1234/nano-graphrag

😭 GraphRAG很强大,但官方的实现阅读或修改起来非常困难。

😊 本项目提供了一个更小、更快、更简洁的 GraphRAG,同时保留了核心功能。

以下是该项目的详细代码注释,作为学习记录和后续修改代码的参考。

2.分模块注释以及分析

为了节省时间只介绍核心模块,即/nano_graphrag中的代码。

2.1 prompt.py

GRAPH_FIELD_SEP = "<SEP>"

PROMPTS = {}

PROMPTS["claim_extraction"],PROMPTS["community_report"],PROMPTS["entity_extraction"],PROMPTS["summarize_entity_descriptions"],

PROMPTS["entiti_continue_extraction"],PROMPTS["entiti_if_loop_extraction"],

PROMPTS["DEFAULT_ENTITY_TYPES"],PROMPTS["DEFAULT_TUPLE_DELIMITER"],PROMPTS["DEFAULT_RECORD_DELIMITER"],PROMPTS["DEFAULT_COMPLETION_DELIMITER"]

PROMPTS["local_rag_response"],PROMPTS["global_reduce_rag_response"],

PROMPTS["fail_response"],PROMPTS["process_tickers"]

使用的prompt是与官方范例相同的内容,这里就不再赘述了。

2.2 _llm.py

函数列表

gpt_4o_complete,

gpt_4o_mini_complete,

openai_embedding,

azure_gpt_4o_complete,

azure_openai_embedding,

azure_gpt_4o_mini_complete

看着很复杂其实相当于只有两个函数,openai_embedding是调用embedding模型。

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    """
    使用OpenAI的API完成文本生成,如果缓存中已存在相同请求则返回缓存结果。
    
    参数:
    - model: 使用的OpenAI模型名称。
    - prompt: 用户的输入提示文本。
    - system_prompt: (可选)系统提示信息,用以引导模型的响应风格或内容。
    - history_messages: (可选)历史对话消息列表,用于聊天上下文。
    - **kwargs: 其他额外参数,如缓存存储对象。

    返回:
    - str: API响应中模型生成的文本内容。
    """
    openai_async_client = AsyncOpenAI()
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    # 如果缓存对象存在,计算当前请求的哈希值,尝试从缓存中获取结果
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    # 如果有缓存对象,将响应结果存入缓存
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
        await hashing_kv.index_done_callback()
    return response.choices[0].message.content


async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )

其他四个都源于openai_complete_if_cache,顾名思义,调用LLM回答问题,如果有缓存则优先调用缓存中的结果,避免资源浪费。

@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
    openai_async_client = AsyncOpenAI()
    response = await openai_async_client.embeddings.create(
        model="text-embedding-3-small", input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])

记录缓存使用的格式是BaseKVStorage在base里定义。

2.3_utils.py

函数列表

logger

日志记录器,几乎每个模块都有调用

logger = logging.getLogger("nano-graphrag")

convert_response_to_json

将响应字符串转换为JSON格式的数据。通过系统变量convert_response_to_json_func: callable = convert_response_to_json来调用。

# 通过正则化匹配,从字符串中提取出JSON字符串
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
    """Locate the JSON string body from a string"""
    maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
    if maybe_json_str is not None:
        return maybe_json_str.group(0)
    else:
        return None

# 将响应字符串转换为JSON格式的数据。
def convert_response_to_json(response: str) -> dict:
    json_str = locate_json_string_body_from_string(response)
    assert json_str is not None, f"Unable to parse JSON from response: {response}"
    try:
        data = json.loads(json_str)
        return data
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise e from None

truncate_list_by_token_size

根据token大小截断列表数据。

def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
    """Truncate a list of data by token size"""
    """
    根据token大小截断列表数据。

    该函数的目的是确保列表中数据的总token数不超过指定的最大token大小。
    当数据的总token数超过最大允许大小时,函数将返回截断后的列表。

    参数:
    - list_data: list, 需要截断的列表,其中每个元素为一个数据项。
    - key: callable, 用于从列表数据项中提取用于计算token大小的字符串的函数。
    - max_token_size: int, 允许的最大token大小,用于决定列表数据的截断点。

    返回:
    - 截断后的列表。如果max_token_size小于等于0,返回空列表。

    注意:
    - 该函数使用tiktoken对字符串进行编码并计算token数量,请确保在使用前已安装tiktoken库。
    - 截断操作基于累计token数量首次超过max_token_size发生的索引位置。
    """
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(encode_string_by_tiktoken(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data

compute_mdhash_id,compute_args_hash

计算哈希值的辅助函数,第二个在_llm.py中已经用过了。

write_json,load_json

用来读写json对象的辅助函数

pack_user_ass_to_openai_messages

将用户和助手的对话打包为OpenAI消息格式。

接受一系列字符串参数,成对地将它们包装成交替的用户和助手角色的消息。

这对于将对话历史记录转换为可供OpenAI的API处理的格式特别有用。在_op.py里就是调用给history的。

def pack_user_ass_to_openai_messages(*args: str):
    """
    将用户和助手的对话打包为OpenAI消息格式。

    该函数接受一系列字符串参数,成对地将它们包装成交替的用户和助手角色的消息。
    这对于将对话历史记录转换为可供OpenAI的API处理的格式特别有用。在_op.py里就是调用给history的。

    参数:
    *args (str): 一个或多个字符串参数,表示用户和助手之间的对话交替发言。

    返回:
    list: 一个字典列表,每个字典包含两个键值对:
        - 'role': 表示消息发送者的角色,根据参数序列中的位置交替为'user'或'assistant'。
        - 'content': 发送者发送的消息内容,来自输入参数序列中的对应位置。
    """
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]

is_float_regex

判断是否是浮点数

split_string_by_multi_markers

通过多个标记(markers)分割字符串

list_of_list_to_csv

将多维列表转换为CSV格式,用在社区结构部分。具体是_pack_single_community_describe函数。

clean_str

清理字符串,在_op.py的_handle_single_entity_extraction函数中用到。

class EmbeddingFunc

定义一个用于嵌入的函数类。

limit_async_func_call

为异步函数添加最大并发调用次数的限制。

def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Add restriction of maximum async calling times for a async func"""
    """
    为异步函数添加最大并发调用次数的限制。

    参数:
    - max_size: int, 允许的最大并发调用次数。
    - waitting_time: float, 当达到最大并发调用次数时,每次检查间隔的时间(秒),默认为0.0001秒。

    返回:
    - 返回一个装饰器函数,用于包装需要限制并发调用的异步函数。
    """
    def final_decro(func):
        """Not using async.Semaphore to aovid use nest-asyncio"""
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            # 如果当前调用数量达到最大值,等待一段时间后再检查
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            result = await func(*args, **kwargs)
            __current_size -= 1
            return result

        return wait_func

    return final_decro

wrap_embedding_func_with_attrs

用属性包装一个函数。

该函数返回一个装饰器,可用于为一个函数添加额外的属性参数。这些属性

被用于在函数执行时提供或记录额外的信息,比如函数的来源、类型等。

这个在_llm.py中调用,是给embedding函数添加了embedding_dim=1536, max_token_size=8192属性

2.4 base.py

QueryParam

查询参数类,用于定义查询时的各种配置选项。

class QueryParam:
    mode: Literal["local", "global", "naive"] = "global"
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    level: int = 2
    top_k: int = 20
    # naive search
    naive_max_token_for_text_unit = 12000
    # local search
    local_max_token_for_text_unit: int = 4000  # 12000 * 0.33
    local_max_token_for_local_context: int = 4800  # 12000 * 0.4
    local_max_token_for_community_report: int = 3200  # 12000 * 0.27
    local_community_single_one: bool = False
    # global search
    global_min_community_rating: float = 0
    global_max_consider_community: float = 512
    global_max_token_for_community_report: int = 16384
    global_special_community_map_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )

TextChunkSchema,SingleCommunitySchema

文本块,社区存储结构

class CommunitySchema

添加了字符串格式以及json格式报告的社区class,在_storage.py中有更详细的定义

class StorageNameSpace

存储命名空间类,用于管理存储操作。有namespace和global_config两个属性。是后面几个类的基础父类。

class BaseVectorStorage

基础向量存储类,继承自 StorageNameSpace。补充了embedding_func和meta_fields属性。用来定义各种向量存储,比如entities_vdb: BaseVectorStorage。

方法:

query: 查询方法,具体函数在_storage.py中,下同。

upsert: 插入或更新方法。

class BaseKVStorage

基础键值存储类,继承自 StorageNameSpace。用来定义各种键值存储,比如

community_reports: BaseKVStorage[CommunitySchema],

text_chunks_db: BaseKVStorage[TextChunkSchema]。

方法:

all_keys: 获取所有键。

get_by_id: 根据 ID 获取数据。

get_by_ids: 根据多个 ID 获取数据。

filter_keys: 筛选出不存在的键。

upsert: 插入或更新数据。

drop: 删除整个存储空间。

class BaseGraphStorage

基础图存储类,继承自 StorageNameSpace。用来定义各种图结构的存储,比知识图谱knwoledge_graph_inst: BaseGraphStorage

方法:

has_node: 判断节点是否存在。

has_edge: 判断边是否存在。

node_degree: 获取节点度数。

edge_degree: 获取边的度数。

get_node: 获取节点信息。

get_edge: 获取边信息。

get_node_edges: 获取节点的所有边。

upsert_node: 插入或更新节点。

upsert_edge: 插入或更新边。

clustering: 进行图聚类。

community_schema: 获取社区结构。

embed_nodes: 对节点进行嵌入。

2.5 _storage.py

里面存储了大量数据操作的函数实现,并且定义了知识图谱存储方式的类别,所以代码很长。

class JsonKVStorage

继承自BaseKVStorage,基本相同,并且包含方法的具体实现代码。

class JsonKVStorage(BaseKVStorage):
    def __post_init__(self):
        # 初始化时,根据全局配置确定工作目录,以便获取文件完整路径
        working_dir = self.global_config["working_dir"]
        # 根据命名空间生成特定的 JSON 文件名,用于存储键值数据
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
        # 加载存储的数据,如果文件不存在或为空,则初始化为空字典
        self._data = load_json(self._file_name) or {}
        # 打印日志,显示加载的数据条数
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
    # 获取所有的键列表
    async def all_keys(self) -> list[str]:
        return list(self._data.keys())
    # 索引操作完成后,将当前数据写入 JSON 文件
    async def index_done_callback(self):
        write_json(self._data, self._file_name)
    # 通过 ID 获取数据
    async def get_by_id(self, id):
        return self._data.get(id, None)
    
    async def get_by_ids(self, ids, fields=None):
        """
        根据ID列表获取数据项。

        参数:
        ids (list): 需要获取数据的ID列表。
        fields (list, 可选): 限制返回数据中的字段。如果未提供,默认为None,将返回完整数据项。

        返回:
        list: 包含按指定ID列表顺序排列的数据项的列表。如果某些ID未找到数据项,则相应位置为None。
        """
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        return [
            (
                # 如果数据项存在,并且ID在_data字典中,则构建一个仅包含fields中字段的新字典
                {k: v for k, v in self._data[id].items() if k in fields}
                # 检查_id是否在数据集中,以避免KeyError
                if self._data.get(id, None)
                else None
            )
            for id in ids
        ]
    # 过滤出不在数据存储中的键列表
    async def filter_keys(self, data: list[str]) -> set[str]:
        return set([s for s in data if s not in self._data])
    # 插入或更新数据
    async def upsert(self, data: dict[str, dict]):
        self._data.update(data)
    # 清空当前存储数据
    async def drop(self):
        self._data = {}

class NanoVectorDBStorage

继承自BaseVectorDBStorage,还是用了类别NanoVectorDB。

class NanoVectorDBStorage(BaseVectorStorage):
    # 余弦相似度阈值,决定返回的结果质量
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        # 初始化向量数据库存储文件和嵌入配置
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        # 初始化向量数据库客户端(NanoVectorDB),并设置嵌入维度
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        # 从全局配置中获取查询的相似度阈值,或使用默认值
        self.cosine_better_than_threshold = self.global_config.get(
            "query_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        """
        插入或更新向量数据。

        该方法用于将字典形式的数据插入或更新到向量数据库中。数据首先被转换成适合插入的格式,
        然后分批处理,以避免一次性插入过多数据导致的性能问题。之后,使用异步方式计算各批次数据的嵌入向量,
        并将这些向量附加到数据条目中,最后调用客户端的插入或更新方法完成操作。

        参数:
        data: dict[str, dict] - 一个字典,键是数据的唯一标识,值是包含实际数据内容的字典。

        返回:
        插入或更新操作的结果。
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
       # 将数据转换为适合插入的列表,并提取内容
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # 将数据按批次进行处理
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        # 异步计算各批次数据的嵌入向量
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        # 将所有批次的嵌入向量合并为一个大数组
        embeddings = np.concatenate(embeddings_list)
        # 将计算得到的嵌入向量附加到每个数据条目中
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        # 调用客户端的插入或更新方法,完成数据的插入或更新
        results = self._client.upsert(datas=list_data)
        return results

    async def query(self, query: str, top_k=5):
        """
        根据提供的查询字符串获取最相关的文档。

        此异步方法使用预训练的embedding函数将查询转换为嵌入表示,
        然后在嵌入索引中搜索与查询最相似的文档。

        参数:
        - query: str,用户查询的字符串。
        - top_k: int,返回最相关的文档数量,默认为5。

        返回:
        - 一个列表,包含最相关的文档及其与查询的相似度距离。
        """
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )
        # 整理结果,添加文档id和距离信息
        results = [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
        ]
        return results

    async def index_done_callback(self):
        self._client.save()

HNSWVectorStorage(另一种向量存储方式,这里先略过)

class NetworkXStorage

继承自BaseGraphStorage,存储图结构的数据。

可以先参考这个链接学习一些图(graph)知识。

图的一些基本知识(连通图、连通分量、最小生成树等知识的基本介绍)-CSDN博客

这里社区发现算法还是调用的外部库,和微软开源的GraphRAG一致。

class NetworkXStorage(BaseGraphStorage):
    # 加载并返回一个NetworkX图,存储格式为graphml。
    @staticmethod
    def load_nx_graph(file_name) -> nx.Graph:
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None
    # 将NetworkX图写入graphml文件。
    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        """返回图的最大连通分量,并以稳定的方式排序节点和边。

        参数:
            graph (nx.Graph): 输入的 NetworkX 图。

        返回:
            nx.Graph: 输入图的最大连通分量,以稳定方式排序。
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        """
        确保无向图以相同的关系读取时始终相同。
        
        参数:
        graph (nx.Graph): 输入的网络图。
        
        返回:
        nx.Graph: 经过稳定处理的网络图。
        """
        # 根据输入图的类型初始化一个新的图实例 
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
        # 对节点进行排序,以确保节点的添加顺序一致
        sorted_nodes = graph.nodes(data=True)
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
        # 向新图中添加排序后的节点
        fixed_graph.add_nodes_from(sorted_nodes)
        # 将边数据存储到列表中以便后续处理
        edges = list(graph.edges(data=True))
        # 如果图不是有向图,则对边进行排序,以确保边的顺序一致
        if not graph.is_directed():

            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    temp = source
                    source = target
                    target = temp
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]
        # 定义获取边的键的函数,用于后续边的排序
        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"
        # 对边进行排序
        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
        # 向新图中添加排序后的边
        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        """
        初始化函数,用于加载图数据并初始化相关属性。

        该函数首先根据全局配置中的工作目录和实例的命名空间来确定graphml文件的路径。
        然后尝试从该路径加载已存在的图数据。如果图数据存在,则使用NetworkXStorage加载,
        并记录日志信息包括图的节点数和边数。如果图数据不存在,则初始化一个新的无向图。
        最后,初始化两个算法字典,分别用于图的聚类算法和节点嵌入算法。
        """
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        self._graph = preloaded_graph or nx.Graph()
        self._clustering_algorithms = {
            "leiden": self._leiden_clustering,
        }
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
    # 将当前存储的图写入到GraphML文件中
    async def index_done_callback(self):
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        """
        异步检查图中是否存在指定节点。下面几个函数类似,就不做标注了,都是调用NetWorkX的函数。

        该方法主要用于确定图结构中是否包含特定的节点。它通过调用底层图对象的has_node方法,
        以高效的方式查询节点是否存在。

        参数:
        node_id (str): 要检查的节点的唯一标识符。

        返回:
        bool: 如果图中存在该节点,则返回True,否则返回False。
        """
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        return self._graph.nodes.get(node_id)
    # 获取指定节点的度数。
    async def node_degree(self, node_id: str) -> int:
        # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
        return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
    # 计算两个节点的度数之和
    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
            self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
        )

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self._graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)
    # 根据指定的算法执行聚类操作。
    async def clustering(self, algorithm: str):
        if algorithm not in self._clustering_algorithms:
            raise ValueError(f"Clustering algorithm {algorithm} not supported")
        await self._clustering_algorithms[algorithm]()

    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
        """
        生成社区架构的异步方法。

        该方法计算并组织图中所有节点的社区结构信息。它通过分析节点的集群信息,
        构建不同层级的社区,并计算社区的各种属性,如层次、节点数、边数等。

        Returns:
            一个字典,键为社区键,值为SingleCommunitySchema对象,包含每个社区的详细信息。
        """
        # 初始化结果字典,为每个社区准备默认属性
        results = defaultdict(
            lambda: dict(
                level=None,
                title=None,
                edges=set(),
                nodes=set(),
                chunk_ids=set(),
                occurrence=0.0,
                sub_communities=[],
            )
        )
        max_num_ids = 0
        # 用于存储不同层级的社区
        levels = defaultdict(set)
        # 遍历图中的所有节点,收集社区信息
        for node_id, node_data in self._graph.nodes(data=True):
            # 如果节点没有集群信息,则跳过
            if "clusters" not in node_data:
                continue
            clusters = json.loads(node_data["clusters"])
            this_node_edges = self._graph.edges(node_id)
            # 遍历节点的所有集群信息
            for cluster in clusters:
                # 提取并更新社区的层级信息
                level = cluster["level"]
                cluster_key = str(cluster["cluster"])
                levels[level].add(cluster_key)
                results[cluster_key]["level"] = level
                results[cluster_key]["title"] = f"Cluster {cluster_key}"
                results[cluster_key]["nodes"].add(node_id)
                results[cluster_key]["edges"].update(
                    [tuple(sorted(e)) for e in this_node_edges]
                )
                results[cluster_key]["chunk_ids"].update(
                    node_data["source_id"].split(GRAPH_FIELD_SEP)
                )
                # 计算最大的chunk_ids数量,用于后续计算出现率
                max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))
        # 对层级进行排序,以便后续处理
        ordered_levels = sorted(levels.keys())
        # 构建子社区关系
        for i, curr_level in enumerate(ordered_levels[:-1]):
            next_level = ordered_levels[i + 1]
            this_level_comms = levels[curr_level]
            next_level_comms = levels[next_level]
            # compute the sub-communities by nodes intersection
            # 通过节点交集计算子社区关系
            for comm in this_level_comms:
                results[comm]["sub_communities"] = [
                    c
                    for c in next_level_comms
                    if results[c]["nodes"].issubset(results[comm]["nodes"])
                ]
        # 处理并标准化结果字典中的数据
        for k, v in results.items():
            v["edges"] = list(v["edges"])
            v["edges"] = [list(e) for e in v["edges"]]
            v["nodes"] = list(v["nodes"])
            v["chunk_ids"] = list(v["chunk_ids"])
            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
        return dict(results)

    def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
        """
        将聚类数据分配到子图

        该方法将给定的聚类数据分配到图中的相应节点。每个节点的聚类信息
        以JSON格式存储在节点的'clusters'属性中。

        参数:
        cluster_data: 字典,包含节点ID和对应的聚类列表。每个聚类是一个字典,
                      包含节点在不同聚类中的信息。
        """
        # 遍历聚类数据字典中的每个节点及其聚类信息
        for node_id, clusters in cluster_data.items():
            # 将节点的聚类信息转换为JSON格式,并存储在图节点的'clusters'属性中
            self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)

    async def _leiden_clustering(self):
        """
        基于Leiden算法的图聚类异步方法。

        本方法使用Leiden算法对图进行聚类,以发现图中的社区结构。
        它从当前图的稳定最大连通组件开始,根据全局配置中的参数进行聚类,
        并将聚类结果转换为子图。
        """
        from graspologic.partition import hierarchical_leiden
        # 获取图的稳定最大连通组件
        graph = NetworkXStorage.stable_largest_connected_component(self._graph)
        # 应用Leiden算法进行层次聚类
        community_mapping = hierarchical_leiden(
            graph,
            max_cluster_size=self.global_config["max_graph_cluster_size"],
            random_seed=self.global_config["graph_cluster_seed"],
        )
        # 准备节点社区字典
        node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
        # 准备用于跟踪各级别社区数量的字典
        __levels = defaultdict(set)
        # 遍历社区映射,构建节点社区字典和级别跟踪字典
        for partition in community_mapping:
            level_key = partition.level
            cluster_id = partition.cluster
            node_communities[partition.node].append(
                {"level": level_key, "cluster": cluster_id}
            )
            __levels[level_key].add(cluster_id)
        # 转换节点社区字典和级别跟踪字典为最终形式
        node_communities = dict(node_communities)
        __levels = {k: len(v) for k, v in __levels.items()}
        # 记录每级社区的数量信息
        logger.info(f"Each level has communities: {dict(__levels)}")
        # 将聚类数据转换为子图
        self._cluster_data_to_subgraphs(node_communities)
    # 根据指定的算法进行节点嵌入。
    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    async def _node2vec_embed(self):
        """
        异步方法,用于通过node2vec算法嵌入图结构数据。
    
        该方法使用graspologic库的node2vec_embed函数,根据内部图结构和配置参数进行图嵌入。
        它首先调用嵌入函数,然后提取嵌入结果中节点的ID,并返回嵌入向量和节点ID列表。
        """
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

2.6 _op.py

一看就知道,最重要的几个函数基本都封装在这里,所以代码量也来到了作者所说的800行级别(实际是1300行)。

┓( ´∀` )┏

函数列表

chunking_by_token_size

用来进行分块。

def chunking_by_token_size(
    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
    """
    根据token大小对文本进行分块。

    该函数用于将给定的文本内容按照指定的token大小限制进行分块,同时保证相邻块之间有重叠。
    主要用于处理大文本,使其能够适应如OpenAI的GPT系列模型的输入限制。

    参数:
    - content: str, 待分块的文本内容。
    - overlap_token_size: int, 默认128. 相邻文本块之间的重叠token数。
    - max_token_size: int, 默认1024. 每个文本块的最大token数。
    - tiktoken_model: str, 默认"gpt-4o". 用于token化和去token化的tiktoken模型。

    返回:
    - List[Dict[str, Any]], 包含每个文本块的tokens数量、文本内容和块顺序索引的列表。
    """
    # 使用指定的tiktoken模型对文本进行token化
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    # 初始化存储分块结果的列表
    results = []
    # 遍历tokens,根据max_token_size和overlap_token_size进行分块
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        # 根据当前分块的起始位置和最大token数限制,获取分块的tokens
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size], model_name=tiktoken_model
        )
        # 将当前分块的tokens数量、文本内容和块顺序索引添加到结果列表中
        results.append(
            {
                "tokens": min(max_token_size, len(tokens) - start),
                "content": chunk_content.strip(),
                "chunk_order_index": index,
            }
        )
    return results

extract_entities

内嵌了实体合并的代码,调用了很多_op.py中的函数,这里就不逐个赘述了。

async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
    """
    异步函数extract_entities从文本块中提取实体并更新知识图谱。

    参数:
    chunks (dict[str, TextChunkSchema]): 文本块字典,键为文本块标识,值为包含文本块内容的TextChunkSchema对象。
    knwoledge_graph_inst (BaseGraphStorage): 知识图谱实例,用于存储提取的实体和关系。
    entity_vdb (BaseVectorStorage): 实体向量数据库实例,用于存储实体的向量表示。
    global_config (dict): 全局配置字典,包含模型函数、最大迭代次数等参数。

    返回:
    Union[BaseGraphStorage, None]: 更新后的知识图谱实例,如果没有提取到任何实体,则返回None。
    """
    # 从全局配置中提取模型函数和实体提取最大迭代次数
    use_llm_func: callable = global_config["best_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    # 将文本块排序,以便按顺序处理
    ordered_chunks = list(chunks.items())
    # 准备实体提取的提示模板和上下文基础信息
    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    )
    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
    # 初始化计数器
    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """
        异步函数_process_single_content处理单个文本块的内容,提取实体并更新知识图谱。

        参数:
        chunk_key_dp (tuple[str, TextChunkSchema]): 文本块的键值对,包含文本块标识和内容。

        返回:
        dict: 包含从文本块中提取的可能的节点和边的字典。
        """
        # 初始化非局部变量,用于跟踪处理进度和统计信息
        nonlocal already_processed, already_entities, already_relations
        # 解析文本块的键和数据
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        # 构建提示信息并调用模型函数提取实体
        hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
        final_result = await use_llm_func(hint_prompt)
        # 构建对话历史并进行多次迭代以提取更多实体
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
            final_result += glean_result
            # 检查是否继续迭代
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func(
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break
        # 解析结果并更新知识图谱
        records = split_string_by_multi_markers(
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )
            # 处理实体提取
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue
            # 处理关系提取
            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        # 更新处理进度和统计信息
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)
    # 并发处理所有文本块
    # use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callings
    results = await asyncio.gather(
        *[_process_single_content(c) for c in ordered_chunks]
    )
    print()  # clear the progress bar
    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            # it's undirected graph
            maybe_edges[tuple(sorted(k))].extend(v)
    all_entities_data = await asyncio.gather(
        *[
            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
            for k, v in maybe_nodes.items()
        ]
    )
    await asyncio.gather(
        *[
            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
            for k, v in maybe_edges.items()
        ]
    )
    if not len(all_entities_data):
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)
    return knwoledge_graph_inst

generate_community_report

生成社区报告,同样调用了很多其他函数。

async def generate_community_report(
    community_report_kv: BaseKVStorage[CommunitySchema],
    knwoledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    """
    异步生成社区报告函数

    该函数用于异步生成社区报告,根据提供的知识图谱实例和全局配置,
    从社区模式中提取数据,并利用语言模型生成社区报告。

    参数:
    - community_report_kv: 实现了社区模式存储的BaseKVStorage实例。
    - knwoledge_graph_inst: 实现了知识图谱存储的BaseGraphStorage实例。
    - global_config: 包含生成社区报告所需的各种配置项的字典。

    返回值:
    无返回值,但会打印处理进度,并将生成的社区报告存储在community_report_kv中。
    """
    # 从全局配置中提取语言模型的额外参数、使用的模型函数和字符串到JSON的转换函数
    llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]
    use_llm_func: callable = global_config["best_model_func"]
    use_string_json_convert_func: callable = global_config[
        "convert_response_to_json_func"
    ]
    # 加载社区报告的提示模板
    community_report_prompt = PROMPTS["community_report"]
    # 从知识图谱实例中获取所有社区模式,并初始化已处理社区计数器
    communities_schema = await knwoledge_graph_inst.community_schema()
    community_keys, community_values = list(communities_schema.keys()), list(
        communities_schema.values()
    )
    already_processed = 0
    # 定义异步函数_form_single_community_report,用于生成单个社区的报告
    async def _form_single_community_report(
        community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]
    ):
        nonlocal already_processed
        # 为当前社区生成描述文本
        describe = await _pack_single_community_describe(
            knwoledge_graph_inst,
            community,
            max_token_size=global_config["best_model_max_token_size"],
            already_reports=already_reports,
            global_config=global_config,
        )
        # 构建完整的prompt并使用语言模型生成响应
        prompt = community_report_prompt.format(input_text=describe)
        response = await use_llm_func(prompt, **llm_extra_kwargs)
        # 将响应转换为JSON格式,并更新已处理社区计数器和进度打印
        data = use_string_json_convert_func(response)
        already_processed += 1
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} communities\r",
            end="",
            flush=True,
        )
        return data
    # 按社区级别排序,并从高到低处理
    levels = sorted(set([c["level"] for c in community_values]), reverse=True)
    logger.info(f"Generating by levels: {levels}")
    community_datas = {}
    for level in levels:
        # 筛选当前级别的社区,并并行生成报告
        this_level_community_keys, this_level_community_values = zip(
            *[
                (k, v)
                for k, v in zip(community_keys, community_values)
                if v["level"] == level
            ]
        )
        this_level_communities_reports = await asyncio.gather(
            *[
                _form_single_community_report(c, community_datas)
                for c in this_level_community_values
            ]
        )
        # 将生成的报告整合到社区数据字典中
        community_datas.update(
            {
                k: {
                    "report_string": _community_report_json_to_str(r),
                    "report_json": r,
                    **v,
                }
                for k, r, v in zip(
                    this_level_community_keys,
                    this_level_communities_reports,
                    this_level_community_values,
                )
            }
        )
    print()  # clear the progress bar
    await community_report_kv.upsert(community_datas)

local_query

本地查询。

async def local_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
    """
    根据查询和配置,执行本地查询并返回结果。

    该函数首先根据查询和一系列存储实例构建本地查询上下文。如果查询参数指示只需要上下文,
    则直接返回上下文字符串。如果上下文为空,则返回预定义的查询失败响应。
    否则,将根据上下文和查询参数构造系统提示,并使用全局配置中指定的最佳模型函数生成响应。

    参数:
        query: 查询字符串。
        knowledge_graph_inst: 知识图谱存储实例。
        entities_vdb: 实体向量存储实例。
        community_reports: 社区报告键值存储实例。
        text_chunks_db: 文本块键值存储实例。
        query_param: 查询参数,包含查询类型和响应类型等信息。
        global_config: 全局配置字典,包含模型函数等关键配置。

    返回:
        根据查询和上下文生成的响应字符串。
    """
    # 从全局配置中获取最佳模型函数
    use_model_func = global_config["best_model_func"]
    # 构建本地查询上下文
    context = await _build_local_query_context(
        query,
        knowledge_graph_inst,
        entities_vdb,
        community_reports,
        text_chunks_db,
        query_param,
    )
    # 如果查询参数指示只需要上下文,则直接返回上下文
    if query_param.only_need_context:
        return context
    # 如果上下文为空,则返回查询失败的响应
    if context is None:
        return PROMPTS["fail_response"]
    # 根据预定义模板构造系统提示,包含上下文数据和查询参数中的响应类型
    sys_prompt_temp = PROMPTS["local_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    # 使用最佳模型函数生成查询响应
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )
    return response

global_query

全局查询,并没有使用map-reduce的方法,而是直接把社区摘要进行了top-k排序。

async def global_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    community_reports: BaseKVStorage[CommunitySchema],
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
) -> str:
    """
    异步执行全局查询。

    此函数根据查询参数和配置从知识图谱、实体向量存储、社区报告及文本块数据库中获取相关信息,
    并对这些信息进行处理和排序,最终生成一个综合的回答。

    参数:
    - query: 查询字符串
    - knowledge_graph_inst: 知识图谱存储实例
    - entities_vdb: 实体向量存储实例
    - community_reports: 社区报告存储实例,存储类型为BaseKVStorage
    - text_chunks_db: 文本块存储实例,存储类型为BaseKVStorage
    - query_param: 查询参数对象,包含查询级别、最大考虑社区数等参数
    - global_config: 全局配置字典,包含最佳模型函数等关键配置

    返回:
    - str: 查询的最终回答字符串
    """
    # 获取并筛选社区schema
    community_schema = await knowledge_graph_inst.community_schema()
    community_schema = {
        k: v for k, v in community_schema.items() if v["level"] <= query_param.level
    }
    if not len(community_schema):
        return PROMPTS["fail_response"]
    use_model_func = global_config["best_model_func"]
    # 对社区schema进行排序
    sorted_community_schemas = sorted(
        community_schema.items(),
        key=lambda x: x[1]["occurrence"],
        reverse=True,
    )
    sorted_community_schemas = sorted_community_schemas[
        : query_param.global_max_consider_community
    ]
    community_datas = await community_reports.get_by_ids(
        [k[0] for k in sorted_community_schemas]
    )
    community_datas = [c for c in community_datas if c is not None]
    community_datas = [
        c
        for c in community_datas
        if c["report_json"].get("rating", 0) >= query_param.global_min_community_rating
    ]
    community_datas = sorted(
        community_datas,
        key=lambda x: (x["occurrence"], x["report_json"].get("rating", 0)),
        reverse=True,
    )
    logger.info(f"Revtrieved {len(community_datas)} communities")
    # 映射社区数据并聚合支持点
    map_communities_points = await _map_global_communities(
        query, community_datas, query_param, global_config
    )
    final_support_points = []
    for i, mc in enumerate(map_communities_points):
        for point in mc:
            if "description" not in point:
                continue
            final_support_points.append(
                {
                    "analyst": i,
                    "answer": point["description"],
                    "score": point.get("score", 1),
                }
            )
    final_support_points = [p for p in final_support_points if p["score"] > 0]
    if not len(final_support_points):
        return PROMPTS["fail_response"]
    final_support_points = sorted(
        final_support_points, key=lambda x: x["score"], reverse=True
    )
    final_support_points = truncate_list_by_token_size(
        final_support_points,
        key=lambda x: x["answer"],
        max_token_size=query_param.global_max_token_for_community_report,
    )
    points_context = []
    for dp in final_support_points:
        points_context.append(
            f"""----Analyst {dp['analyst']}----
Importance Score: {dp['score']}
{dp['answer']}
"""
        )
    points_context = "\n".join(points_context)
    if query_param.only_need_context:
        return points_context
    sys_prompt_temp = PROMPTS["global_reduce_rag_response"]
    response = await use_model_func(
        query,
        sys_prompt_temp.format(
            report_data=points_context, response_type=query_param.response_type
        ),
    )
    return response

naive_query

朴素查询,不做重点。

2.7 graphrag.py

这回终于可以看主函数了,但有了前面的铺垫,具体代码就比较简单了。不过里面包含了不少增量式的处理方法,只要理解基本思想和功能看代码也会清楚一些。目标是可以实现输入a文档后,再输入b文档补充到a文档生成的知识图谱中,也就是图扩展。以后可以考虑在此基础上做图更新。

函数列表

always_get_an_event_loop()

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # If there is already an event loop, use it.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # If in a sub-thread, create a new event loop.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop

class GraphRAG

主要函数类别,实现模块化调用。

insert()通过always_get_an_event_loop()调用异步函数ainsert(),以实现增量添加文本文件,query()和aquery()同理。

ainsert()——分块,实体关系提取,知识图谱聚类,社区报告的函数调用

aquery()——local,global,naive

class GraphRAG:
    # 表示工作目录的路径,默认是根据当前日期时间生成的目录
    working_dir: str = field(
        default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )
    # graph mode
    # 是否本地存储,默认是本地存储
    enable_local: bool = True
    # 是否启用朴素rag,默认不启用
    enable_naive_rag: bool = False

    # text chunking
    chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
    # 分块大小
    chunk_token_size: int = 1200
    # 分块重叠数量
    chunk_overlap_token_size: int = 100
    # tiktoken使用的模型名字,默认为gpt-4o
    tiktoken_model_name: str = "gpt-4o"

    # entity extraction
    # 实体提取最大“拾取”次数,也就是反复提取次数,默认不反复提取
    entity_extract_max_gleaning: int = 1
    # 实体摘要最大token数
    entity_summary_to_max_tokens: int = 500

    # graph clustering
    # 社区聚类算法,默认为莱顿算法
    graph_cluster_algorithm: str = "leiden"
    # 最大社区聚类节点数,默认为10
    max_graph_cluster_size: int = 10
    # 社区聚类随机种子,默认为0xDEADBEEF
    graph_cluster_seed: int = 0xDEADBEEF

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    # 如果没有显式传入 node2vec_params,则会调用这个 lambda 函数,自动生成并赋值为这个默认的字典
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "num_walks": 10,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # community reports
    # 以json格式返回社区报告
    special_community_report_llm_kwargs: dict = field(
        default_factory=lambda: {"response_format": {"type": "json_object"}}
    )

    # text embedding
    # 默认使用openai的embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    # 批量大小,默认为32
    embedding_batch_num: int = 32
    # 最大并发请求数量
    embedding_func_max_async: int = 16
    # 用于比较查询结果质量的阈值
    query_better_than_threshold: float = 0.2

    # LLM相关调用
    # 需要两种类型的大语言模型(LLM),一种是高性能的,用于规划和回应;另一种是成本较低的,用于总结
    using_azure_openai: bool = False
    best_model_func: callable = gpt_4o_complete
    best_model_max_token_size: int = 32768
    best_model_max_async: int = 16
    cheap_model_func: callable = gpt_4o_mini_complete
    cheap_model_max_token_size: int = 32768
    cheap_model_max_async: int = 16

    # entity extraction
    # 实体提取函数
    entity_extraction_func: callable = extract_entities

    # storage
    # 存储类型设置
    # 键值存储,json,具体定义在_storage.py
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    # 向量数据库存储,具体定义在_storage.py
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    # 为向量数据库存储类提供可选的参数字典
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    # 图数据库存储,默认为NetworkXStorage
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    # 是否启用llm缓存
    enable_llm_cache: bool = True

    # extension
    # 用于传递额外参数的字典
    addon_params: dict = field(default_factory=dict)
    # 用于将 LLM 输出转换为 JSON 的函数
    convert_response_to_json_func: callable = convert_response_to_json

     # 在对象初始化后调用此方法,主要作用为打印配置信息和根据配置进行一些设置调整
    def __post_init__(self):
        # 将对象的属性以键值对的形式打印出来,用于调试和日志记录
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"GraphRAG init with param:\n\n  {_print_config}\n")

        # 如果配置了使用Azure OpenAI服务,则调整相关函数为Azure版本
        if self.using_azure_openai:
            if self.best_model_func == gpt_4o_complete:
                self.best_model_func = azure_gpt_4o_complete
            if self.cheap_model_func == gpt_4o_mini_complete:
                self.cheap_model_func = azure_gpt_4o_mini_complete
            if self.embedding_func == openai_embedding:
                self.embedding_func = azure_openai_embedding
            logger.info(
                "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
            )
        
        # 确保工作目录存在,如果不存在则创建
        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)
        # 初始化存储类实例,用于存储完整文档
        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        # 初始化存储类实例,用于存储文本块
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )
        # 根据配置初始化LLM响应缓存,如果配置为启用则创建缓存实例
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        # 初始化存储类实例,用于存储社区报告
        self.community_reports = self.key_string_value_json_storage_cls(
            namespace="community_reports", global_config=asdict(self)
        )
        # 初始化图存储类实例,用于存储块实体关系图
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )
        # 限制embedding函数的异步调用次数
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        # 根据配置初始化向量数据库存储类实例,用于存储实体
        self.entities_vdb = (
            self.vector_db_storage_cls(
                namespace="entities",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
                meta_fields={"entity_name"},
            )
            if self.enable_local
            else None
        )
        # 根据配置初始化向量数据库存储类实例,用于存储块
        self.chunks_vdb = (
            self.vector_db_storage_cls(
                namespace="chunks",
                global_config=asdict(self),
                embedding_func=self.embedding_func,
            )
            if self.enable_naive_rag
            else None
        )
        # 限制最佳模型函数的异步调用次数,并为其配置哈希键值存储
        self.best_model_func = limit_async_func_call(self.best_model_max_async)(
            partial(self.best_model_func, hashing_kv=self.llm_response_cache)
        )
        # 限制廉价模型函数的异步调用次数,并为其配置哈希键值存储
        self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
            partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        """
        异步查询函数,根据指定的查询模式执行相应的查询。

        该函数支持三种查询模式:本地模式 ("local")、全局模式 ("global") 和朴素模式 ("naive")。
        根据提供的参数确定使用哪种查询模式,并在查询完成后执行查询结束的钩子函数。

        参数:
        - query: str, 要查询的字符串。
        - param: QueryParam, 查询参数对象,包含查询模式和其他查询相关的参数。

        返回:
        - response: 从查询函数返回的响应。
        
        异常:
        - ValueError: 当尝试在不支持的模式下执行查询时抛出。
        """
        if param.mode == "local" and not self.enable_local:
            raise ValueError("enable_local is False, cannot query in local mode")
        if param.mode == "naive" and not self.enable_naive_rag:
            raise ValueError("enable_naive_rag is False, cannot query in local mode")
        # 根据查询模式执行相应的查询函数
        if param.mode == "local":
            response = await local_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "global":
            response = await global_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.community_reports,
                self.text_chunks,
                param,
                asdict(self),
            )
        elif param.mode == "naive":
            response = await naive_query(
                query,
                self.chunks_vdb,
                self.text_chunks,
                param,
                asdict(self),
            )
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def ainsert(self, string_or_strings):
        """
        异步插入字符串或字符串列表到文档和片段数据库中,并更新知识图谱和社区报告。

        参数:
            string_or_strings (str 或 List[str]): 要插入的字符串或字符串列表。
        """
        try:
            # 如果输入是一个字符串,将其转换为列表
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]
            # ---------- new docs
            # 将字符串或字符串列表string_or_strings中的每个元素去除首尾空白后,作为文档内容。
            # 计算其MD5哈希值并添加前缀doc-作为键,内容本身作为值,生成一个新的字典new_docs。
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            # 筛选出需要添加的新文档ID
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            # 根据筛选结果更新新文档字典
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            # 如果没有新文档需要添加,记录日志并返回
            if not len(new_docs):
                logger.warning(f"All docs are already in the storage")
                return
             # 记录插入新文档的日志
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # ---------- chunking
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                # 为每个文档生成片段
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in self.chunk_func(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            # 筛选出需要添加的新片段ID
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            # 根据筛选结果更新新片段字典
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            # 如果没有新片段需要添加,记录日志并返回
            if not len(inserting_chunks):
                logger.warning(f"All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
            # 如果启用了简单的RAG,则插入片段到相应的数据库
            if self.enable_naive_rag:
                logger.info("Insert chunks for naive RAG")
                await self.chunks_vdb.upsert(inserting_chunks)
            # 由于目前不支持增量更新社区,因此删除所有现有的社区报告
            # TODO: no incremental update for communities now, so just drop all
            await self.community_reports.drop()

            # ---------- extract/summary entity and upsert to graph
            # 提取新实体和关系,并更新到知识图谱中
            logger.info("[Entity Extraction]...")
            maybe_new_kg = await self.entity_extraction_func(
                inserting_chunks,
                knwoledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg
            # ---------- update clusterings of graph
            # 更新知识图谱的聚类
            logger.info("[Community Report]...")
            await self.chunk_entity_relation_graph.clustering(
                self.graph_cluster_algorithm
            )
            # 生成并更新社区报告
            await generate_community_report(
                self.community_reports, self.chunk_entity_relation_graph, asdict(self)
            )

            # ---------- commit upsertings and indexing
            # 提交所有更新和索引操作
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            await self._insert_done()

    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.community_reports,
            self.entities_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    async def _query_done(self):
        """
        异步函数,用于在查询完成后,确保所有相关的存储实例的索引更新操作完成。
        
        1. 遍历预定义的存储实例列表(当前只有一个llm_response_cache)。
        2. 对于每个非空的存储实例,添加其索引更新完成的回调函数到任务列表。
        3. 使用asyncio.gather等待所有添加到任务列表中的回调函数执行完成。
        
        该函数的主要作用是确保在查询完成后,所有配置的存储实例都完成了它们的索引更新操作,
        这对于保持数据一致性和完整性至关重要。
        """
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            # 将存储实例的索引更新完成的回调函数添加到任务列表
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

3.使用测试

这里我采用本地embedding和调用kimi作为大模型来测试结果,根据官方例子编写了如下的代码。

import os
import sys
import logging
from openai import AsyncOpenAI
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
from sentence_transformers import SentenceTransformer
import numpy as np
sys.path.append("..")

logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)

API_KEY = "填你自己的"
MODEL = "moonshot-v1-32k"
URL = "https://api.moonshot.cn/v1"
QUESTION = "请概括故事的主要情节并分析故事的主旨。"
WORKING_DIR = "./workspace"
FILEPATH = "./txt/追逐雪的人.txt"
EMBED_MODEL = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
)

async def model_if_cache(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    openai_async_client = AsyncOpenAI(
        api_key=API_KEY, base_url=URL
    )
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Get the cached response if having-------------------
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(MODEL, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    # -----------------------------------------------------

    response = await openai_async_client.chat.completions.create(
        model=MODEL, messages=messages, **kwargs
    )

    # Cache the response if having-------------------
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
        )
    # -----------------------------------------------------
    return response.choices[0].message.content


def remove_if_exist(file):
    if os.path.exists(file):
        os.remove(file)


def query():
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        best_model_func=model_if_cache,
        cheap_model_func=model_if_cache,
        embedding_func=local_embedding
    )
    print(
        rag.query(
            QUESTION, param=QueryParam(mode="global")
        )
    )


@wrap_embedding_func_with_attrs(
    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
    max_token_size=EMBED_MODEL.max_seq_length,
)

async def local_embedding(texts: list[str]) -> np.ndarray:
    return EMBED_MODEL.encode(texts, normalize_embeddings=True)

def insert():
    from time import time

    with open(FILEPATH, encoding="utf-8-sig") as f:
        FAKE_TEXT = f.read()

    remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
    remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
    remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")

    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        best_model_func=model_if_cache,
        cheap_model_func=model_if_cache,
        embedding_func=local_embedding
    )
    start = time()
    rag.insert(FAKE_TEXT)
    print("indexing time:", time() - start)


if __name__ == "__main__":
    insert()
    query()

回答结果如下

### 故事主要情节概括

故事主要围绕两个核心人物——"她"和"我"——以及他们的宠物"花卷"展开。这两个人物通过共同的活动、
深入的对话和相互关怀来加深彼此的情感联系。他们一起经历了各种活动和情感交流,包括参观博物馆、
讨论艺术作品,以及参与社区中的艺术和教育活动。宠物"花卷"在故事中扮演了情感纽带的角色,增强了
两位主角之间的联系,并且是他们共同关心和照顾的对象。

### 故事主旨分析

1. **人际关系与情感联系**:故事的主旨在于探索人与人之间的深层情感联系,以及在共同经历和相互
支持中成长和发现生活的意义。"她"和"我"之间的关系是故事的核心,他们通过亲密的旅程,强调了在面
对生活中的挑战和困难时,人与人之间的相互支持和关怀的重要性。

2. **社区互动与个人成长**:故事还展现了社区互动和个人成长的重要性。一个男孩在社区中的互动和
成长经历,涉及学校活动、艺术任务和个人关系,表明社区成员之间的互动对个人发展具有重要影响。艺
术和教育活动在促进社区凝聚力和个人发展中发挥了重要作用。

3. **艺术与美的欣赏**:故事中包含了对艺术和美的欣赏,两位主角在参观博物馆和讨论艺术作品时展
现了他们的情感智慧和共同的感性。艺术活动,特别是海报着色任务,是连接社区成员的重要活动,促进
社区参与和艺术表达。

4. **教育与权威角色**:老师在社区中具有双重角色,既是权威人物,又是教育活动的促进者,管理课
堂行为并支持艺术任务的完成。这表明教育者在塑造社区文化和促进个人成长中扮演着关键角色。

5. **导师与钦佩动态**:男孩与俱乐部领袖的关系表明社区内存在强烈的导师或钦佩动态,影响男孩在
社区中的动机和参与度。这种关系对个人成长和社区凝聚力具有积极影响。

综上所述,故事通过"她"、"我"和宠物"花卷"的亲密旅程,以及他们在社区中的互动和成长经历,探讨了
人际关系、社区互动、艺术欣赏、教育角色和导师动态等多重主题,展现了人与人之间的情感联系、个人
成长和社区凝聚力的重要性。

生成的文件均为json格式,知识图谱是graphml格式。

​​​​​​​

构建的知识图谱如下图。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值