"""
参考:https://microsoft.github.io/graphrag/posts/get_started/
1. 初始化家目录:python -m graphrag.index --init --root ./ragtest
2. 初始化索引:python -m graphrag.index --root ./ragtest
脚本需要放置在ragtest目录下运行
"""
import os
import re
from pathlib import Path
from typing import cast, Union, Tuple
import pandas as pd
from graphrag.config import (
GraphRagConfig,
create_graphrag_config,
)
from graphrag.index.progress import PrintProgressReporter
from graphrag.query.input.loaders.dfs import (
store_entity_semantic_embeddings,
)
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.query.factories import get_local_search_engine
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
# Module-level progress reporter (empty prefix) shared by LocalSearchEngine
# for settings-loading status messages.
reporter = PrintProgressReporter("")
class LocalSearchEngine:
    """
    Local search engine adapted from the official GraphRAG sample code.

    The search agent is built once at start-up so repeated queries do not
    reload the index artifacts; only a single query entry point
    (``run_search``) is exposed.

    Supported ``response_type`` values: Multiple Paragraphs, Single Paragraph,
    Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report.
    """

    def __init__(self, data_dir: Union[str, None], root_dir: Union[str, None]):
        """
        :param data_dir: directory containing the index artifacts (parquet
            files); inferred from ``root_dir`` when ``None``.
        :param root_dir: GraphRAG project root holding settings.yaml/.json.
        :raises ValueError: if both ``data_dir`` and ``root_dir`` are ``None``.
        """
        self.data_dir, self.root_dir, self.config = self._configure_paths_and_settings(
            data_dir, root_dir
        )
        # NOTE(review): search_agent() builds its own store from the config;
        # this default-configured store is kept only because callers may read
        # the attribute.
        self.description_embedding_store = self._get_embedding_description_store()
        self.agent = self.search_agent(
            community_level=2, response_type="Single Paragraph"
        )

    def _configure_paths_and_settings(
        self, data_dir: Union[str, None], root_dir: Union[str, None]
    ) -> Tuple[str, Union[str, None], GraphRagConfig]:
        """Resolve the artifact directory and load the GraphRAG config."""
        if data_dir is None and root_dir is None:
            msg = "Either data_dir or root_dir must be provided."
            raise ValueError(msg)
        if data_dir is None:
            data_dir = self._infer_data_dir(cast(str, root_dir))
        config = self._create_graphrag_config(root_dir, data_dir)
        return data_dir, root_dir, config

    @staticmethod
    def _infer_data_dir(root: str) -> str:
        """
        Infer the artifact directory: the ``artifacts`` folder of the most
        recently modified run under ``<root>/output``.

        :raises ValueError: when no output folder (or no run) exists.
        """
        output = Path(root) / "output"
        if output.exists():
            # Newest run first, by filesystem modification time.
            folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True)
            if folders:
                folder = folders[0]
                return str((folder / "artifacts").absolute())
        msg = f"Could not infer data directory from root={root}"
        raise ValueError(msg)

    def _create_graphrag_config(
        self, root: Union[str, None], data_dir: Union[str, None]
    ) -> GraphRagConfig:
        """Load the config from ``root`` when given, else from ``data_dir``."""
        return self._read_config_parameters(cast(str, root or data_dir))

    @staticmethod
    def _read_config_parameters(root: str) -> GraphRagConfig:
        """
        Read GraphRAG settings from ``settings.yaml``/``settings.yml``/
        ``settings.json`` under ``root``, falling back to environment
        variables when no settings file is present.
        """
        _root = Path(root)
        settings_yaml = _root / "settings.yaml"
        if not settings_yaml.exists():
            settings_yaml = _root / "settings.yml"
        settings_json = _root / "settings.json"
        if settings_yaml.exists():
            reporter.info(f"Reading settings from {settings_yaml}")
            with settings_yaml.open("rb") as file:
                import yaml

                data = yaml.safe_load(
                    file.read().decode(encoding="utf-8", errors="strict")
                )
                return create_graphrag_config(data, root)
        if settings_json.exists():
            reporter.info(f"Reading settings from {settings_json}")
            with settings_json.open("rb") as file:
                import json

                data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
                return create_graphrag_config(data, root)
        reporter.info("Reading settings from environment variables")
        return create_graphrag_config(root_dir=root)

    @staticmethod
    def _get_embedding_description_store(
        vector_store_type: str = VectorStoreType.LanceDB,
        config_args: Union[dict, None] = None,
    ):
        """
        Create and connect the vector store holding entity description
        embeddings.

        The incoming mapping is copied before modification so the caller's
        config dict (e.g. ``config.embeddings.vector_store``) is not mutated.

        :param vector_store_type: backend type, defaults to LanceDB.
        :param config_args: backend connection/collection settings.
        :return: a connected vector store instance.
        """
        # Copy to avoid mutating the caller's (shared) config dict.
        config_args = dict(config_args) if config_args else {}
        config_args["collection_name"] = config_args.get(
            "query_collection_name",
            config_args.get("collection_name", "description_embedding"),
        )
        description_embedding_store = VectorStoreFactory.get_vector_store(
            vector_store_type=vector_store_type, kwargs=config_args
        )
        description_embedding_store.connect(**config_args)
        return description_embedding_store

    def search_agent(self, community_level: int, response_type: str):
        """
        Build the local search engine from the index artifacts.

        :param community_level: community hierarchy level to load.
        :param response_type: answer format, e.g. "Single Paragraph".
        :return: a configured local search engine instance.
        """
        data_path = Path(self.data_dir)
        final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
        final_community_reports = pd.read_parquet(
            data_path / "create_final_community_reports.parquet"
        )
        final_text_units = pd.read_parquet(
            data_path / "create_final_text_units.parquet"
        )
        final_relationships = pd.read_parquet(
            data_path / "create_final_relationships.parquet"
        )
        final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
        # Covariates (claims) are optional; older indexes may lack this file.
        final_covariates_path = data_path / "create_final_covariates.parquet"
        final_covariates = (
            pd.read_parquet(final_covariates_path)
            if final_covariates_path.exists()
            else None
        )
        vector_store_args = (
            self.config.embeddings.vector_store
            if self.config.embeddings.vector_store
            else {}
        )
        vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
        description_embedding_store = self._get_embedding_description_store(
            vector_store_type=vector_store_type,
            config_args=vector_store_args,
        )
        entities = read_indexer_entities(final_nodes, final_entities, community_level)
        store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
        covariates = (
            read_indexer_covariates(final_covariates)
            if final_covariates is not None
            else []
        )
        return get_local_search_engine(
            self.config,
            reports=read_indexer_reports(
                final_community_reports, final_nodes, community_level
            ),
            text_units=read_indexer_text_units(final_text_units),
            entities=entities,
            relationships=read_indexer_relationships(final_relationships),
            covariates={"claims": covariates},
            description_embedding_store=description_embedding_store,
            response_type=response_type,
        )

    def run_search(self, query: str) -> str:
        """
        Run a local search query.

        :param query: the question to ask.
        :return: the response text with ``[Data: ...]`` citations stripped.
        """
        result = self.agent.search(query=query)
        return self.remove_sources(result.response)

    @staticmethod
    def remove_sources(text: str) -> str:
        """
        Strip citation markers such as ``[Data: Sources (82, 14, 42, 98)]``.

        :param text: response text possibly containing citation markers.
        :return: text with every ``[Data: ...]`` marker removed.
        """
        cleaned_text = re.sub(r'\[Data: [^]]+\]', '', text)
        return cleaned_text
# Example usage
BASEDIR = os.path.dirname(__file__)  # Set your base directory path here
# NOTE(review): this instantiation runs at import time and loads all index
# artifacts (heavy I/O) — presumably so API modules can import a ready-made
# engine; confirm before moving it under the __main__ guard.
local_search_engine = LocalSearchEngine(data_dir=None, root_dir=BASEDIR)
if __name__ == '__main__':
    # Smoke test: run a single local-search query and print the cleaned answer.
    local_res = local_search_engine.run_search(
        query="如何添加设备",
    )
    print(local_res)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
- 104.
- 105.
- 106.
- 107.
- 108.
- 109.
- 110.
- 111.
- 112.
- 113.
- 114.
- 115.
- 116.
- 117.
- 118.
- 119.
- 120.
- 121.
- 122.
- 123.
- 124.
- 125.
- 126.
- 127.
- 128.
- 129.
- 130.
- 131.
- 132.
- 133.
- 134.
- 135.
- 136.
- 137.
- 138.
- 139.
- 140.
- 141.
- 142.
- 143.
- 144.
- 145.
- 146.
- 147.
- 148.
- 149.
- 150.
- 151.
- 152.
- 153.
- 154.
- 155.
- 156.
- 157.
- 158.
- 159.
- 160.
- 161.
- 162.
- 163.
- 164.
- 165.
- 166.
- 167.
- 168.
- 169.
- 170.
- 171.
- 172.
- 173.
- 174.
- 175.
- 176.
- 177.
- 178.
- 179.
- 180.
- 181.
- 182.
- 183.
- 184.
- 185.
- 186.
- 187.
- 188.
- 189.
- 190.
- 191.
- 192.
- 193.
- 194.
- 195.
- 196.
- 197.
- 198.
- 199.
- 200.
- 201.
- 202.
- 203.
- 204.
- 205.
- 206.
- 207.
- 208.
- 209.
- 210.
- 211.
- 212.
- 213.
- 214.
- 215.
搜索方式有global跟local两种。如果想通过api调用global,修改几个关键字就行。
作者:一石数字欠我15w!!!