"""
参考:https://microsoft.github.io/graphrag/posts/get_started/
1. 初始化家目录:python -m graphrag.index --init --root ./ragtest
2. 初始化索引:python -m graphrag.index --root ./ragtest

脚本需要放置在ragtest目录下运行
"""

import os
import re
from pathlib import Path
from typing import cast, Union, Tuple

import pandas as pd

from graphrag.config import (
    GraphRagConfig,
    create_graphrag_config,
)
from graphrag.index.progress import PrintProgressReporter
from graphrag.query.input.loaders.dfs import (
    store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.query.factories import get_local_search_engine
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)

reporter = PrintProgressReporter("")


class LocalSearchEngine:
    """
    Thin wrapper around GraphRAG local search, adapted from the official code.

    The search engine is built once at construction time so repeated queries do
    not reload the index artifacts; callers only need ``run_search``.

    Typical ``response_type`` values: "Multiple Paragraphs", "Single Paragraph",
    "Single Sentence", "List of 3-7 Points", "Single Page", "Multi-Page Report".
    """

    # Matches GraphRAG citation markers such as "[Data: Sources (82, 14, 42, 98)]"
    # plus any spaces/tabs directly in front of them, so removing a marker does
    # not leave a dangling space before the following punctuation.
    _SOURCE_REF_PATTERN = re.compile(r"[ \t]*\[Data: [^\]]+\]")

    def __init__(
        self,
        data_dir: Union[str, None],
        root_dir: Union[str, None],
        community_level: int = 2,
        response_type: str = "Single Paragraph",
    ):
        """
        :param data_dir: directory holding the index parquet artifacts; when None
            it is inferred from the newest run folder under ``<root_dir>/output``.
        :param root_dir: project root containing settings.yaml / settings.json.
        :param community_level: community level passed to the indexer readers
            when selecting entities and community reports.
        :param response_type: answer format passed to the local search engine.
        :raises ValueError: if both directories are None, or the data directory
            cannot be inferred.
        """
        self.data_dir, self.root_dir, self.config = self._configure_paths_and_settings(
            data_dir, root_dir
        )
        # Filled in by search_agent() with the vector store the engine actually
        # queries (previously a second, unused default store was connected here).
        self.description_embedding_store = None
        self.agent = self.search_agent(
            community_level=community_level, response_type=response_type
        )

    def _configure_paths_and_settings(
        self, data_dir: Union[str, None], root_dir: Union[str, None]
    ) -> Tuple[str, Union[str, None], GraphRagConfig]:
        """Resolve the data directory and load the GraphRAG configuration."""
        if data_dir is None and root_dir is None:
            msg = "Either data_dir or root_dir must be provided."
            raise ValueError(msg)
        if data_dir is None:
            data_dir = self._infer_data_dir(cast(str, root_dir))
        config = self._create_graphrag_config(root_dir, data_dir)
        return data_dir, root_dir, config

    @staticmethod
    def _infer_data_dir(root: str) -> str:
        """
        Return the ``artifacts`` path inside the most recently modified run
        folder under ``<root>/output``.

        :raises ValueError: if ``output`` is missing or contains no folders.
        """
        output = Path(root) / "output"
        if output.exists():
            # Only directories are run folders; a stray file (log, .DS_Store, ...)
            # with a newer mtime must not be mistaken for the latest run.
            runs = [entry for entry in output.iterdir() if entry.is_dir()]
            if runs:
                latest = max(runs, key=os.path.getmtime)
                return str((latest / "artifacts").absolute())
        msg = f"Could not infer data directory from root={root}"
        raise ValueError(msg)

    def _create_graphrag_config(
        self, root: Union[str, None], data_dir: Union[str, None]
    ) -> GraphRagConfig:
        """Load the config from ``root``, falling back to ``data_dir`` if root is empty."""
        return self._read_config_parameters(cast(str, root or data_dir))

    @staticmethod
    def _read_config_parameters(root: str) -> GraphRagConfig:
        """
        Build a GraphRagConfig from settings.yaml/.yml, then settings.json, and
        finally from environment variables if no settings file exists.
        """
        _root = Path(root)
        settings_yaml = _root / "settings.yaml"
        if not settings_yaml.exists():
            settings_yaml = _root / "settings.yml"
        settings_json = _root / "settings.json"

        if settings_yaml.exists():
            reporter.info(f"Reading settings from {settings_yaml}")
            with settings_yaml.open("rb") as file:
                import yaml

                data = yaml.safe_load(
                    file.read().decode(encoding="utf-8", errors="strict")
                )
                return create_graphrag_config(data, root)

        if settings_json.exists():
            reporter.info(f"Reading settings from {settings_json}")
            with settings_json.open("rb") as file:
                import json

                data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
                return create_graphrag_config(data, root)

        reporter.info("Reading settings from environment variables")
        return create_graphrag_config(root_dir=root)

    @staticmethod
    def _get_embedding_description_store(
        vector_store_type: str = VectorStoreType.LanceDB,
        config_args: Union[dict, None] = None,
    ):
        """
        Create and connect the vector store for entity description embeddings.

        ``config_args`` is copied before use so the caller's dict (e.g. the
        vector_store section of the loaded GraphRagConfig) is never mutated.
        """
        config_args = dict(config_args) if config_args else {}
        # "query_collection_name" overrides "collection_name"; otherwise default.
        config_args["collection_name"] = config_args.get(
            "query_collection_name",
            config_args.get("collection_name", "description_embedding"),
        )

        description_embedding_store = VectorStoreFactory.get_vector_store(
            vector_store_type=vector_store_type, kwargs=config_args
        )

        description_embedding_store.connect(**config_args)
        return description_embedding_store

    def search_agent(self, community_level: int, response_type: str):
        """
        Build the local search engine from the parquet artifacts in ``self.data_dir``.

        Side effect: stores the vector store used by the engine on
        ``self.description_embedding_store``.
        """
        data_path = Path(self.data_dir)

        final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
        final_community_reports = pd.read_parquet(
            data_path / "create_final_community_reports.parquet"
        )
        final_text_units = pd.read_parquet(
            data_path / "create_final_text_units.parquet"
        )
        final_relationships = pd.read_parquet(
            data_path / "create_final_relationships.parquet"
        )
        final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
        # Covariates (claims) are optional artifacts.
        final_covariates_path = data_path / "create_final_covariates.parquet"
        final_covariates = (
            pd.read_parquet(final_covariates_path)
            if final_covariates_path.exists()
            else None
        )

        vector_store_args = (
            self.config.embeddings.vector_store
            if self.config.embeddings.vector_store
            else {}
        )
        vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

        description_embedding_store = self._get_embedding_description_store(
            vector_store_type=vector_store_type,
            config_args=vector_store_args,
        )
        self.description_embedding_store = description_embedding_store

        entities = read_indexer_entities(final_nodes, final_entities, community_level)
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
        covariates = (
            read_indexer_covariates(final_covariates)
            if final_covariates is not None
            else []
        )

        return get_local_search_engine(
            self.config,
            reports=read_indexer_reports(
                final_community_reports, final_nodes, community_level
            ),
            text_units=read_indexer_text_units(final_text_units),
            entities=entities,
            relationships=read_indexer_relationships(final_relationships),
            covariates={"claims": covariates},
            description_embedding_store=description_embedding_store,
            response_type=response_type,
        )

    def run_search(self, query: str) -> str:
        """
        Search entry point.

        :param query: the question to ask.
        :return: the answer text with "[Data: ...]" citation markers removed.
        """
        result = self.agent.search(query=query)
        return self.remove_sources(result.response)

    @staticmethod
    def remove_sources(text: str) -> str:
        """
        Strip citation markers like "[Data: Sources (82, 14, 42, 98)]" (and the
        spaces/tabs immediately before them) from ``text``.
        """
        return LocalSearchEngine._SOURCE_REF_PATTERN.sub("", text)


# Example usage.
# Directory containing this script (expected to be the GraphRAG root, i.e. the
# ragtest directory). abspath guarantees a non-empty path even if __file__ is
# relative: an empty root_dir would make the config loader fall back to the
# artifacts directory and silently ignore settings.yaml.
BASEDIR = os.path.dirname(os.path.abspath(__file__))

# Built at import time so the index is loaded once; presumably other modules
# import and reuse `local_search_engine`.
local_search_engine = LocalSearchEngine(data_dir=None, root_dir=BASEDIR)

if __name__ == '__main__':
    local_res = local_search_engine.run_search(
        query="如何添加设备",
    )
    print(local_res)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.
  • 92.
  • 93.
  • 94.
  • 95.
  • 96.
  • 97.
  • 98.
  • 99.
  • 100.
  • 101.
  • 102.
  • 103.
  • 104.
  • 105.
  • 106.
  • 107.
  • 108.
  • 109.
  • 110.
  • 111.
  • 112.
  • 113.
  • 114.
  • 115.
  • 116.
  • 117.
  • 118.
  • 119.
  • 120.
  • 121.
  • 122.
  • 123.
  • 124.
  • 125.
  • 126.
  • 127.
  • 128.
  • 129.
  • 130.
  • 131.
  • 132.
  • 133.
  • 134.
  • 135.
  • 136.
  • 137.
  • 138.
  • 139.
  • 140.
  • 141.
  • 142.
  • 143.
  • 144.
  • 145.
  • 146.
  • 147.
  • 148.
  • 149.
  • 150.
  • 151.
  • 152.
  • 153.
  • 154.
  • 155.
  • 156.
  • 157.
  • 158.
  • 159.
  • 160.
  • 161.
  • 162.
  • 163.
  • 164.
  • 165.
  • 166.
  • 167.
  • 168.
  • 169.
  • 170.
  • 171.
  • 172.
  • 173.
  • 174.
  • 175.
  • 176.
  • 177.
  • 178.
  • 179.
  • 180.
  • 181.
  • 182.
  • 183.
  • 184.
  • 185.
  • 186.
  • 187.
  • 188.
  • 189.
  • 190.
  • 191.
  • 192.
  • 193.
  • 194.
  • 195.
  • 196.
  • 197.
  • 198.
  • 199.
  • 200.
  • 201.
  • 202.
  • 203.
  • 204.
  • 205.
  • 206.
  • 207.
  • 208.
  • 209.
  • 210.
  • 211.
  • 212.
  • 213.
  • 214.
  • 215.

搜索方式有 global 和 local 两种。如果想通过 API 调用 global 搜索,只需修改几个关键字即可。

作者:一石数字欠我15w!!!