微软的开源GraphRAG 项目,github地址为:GitHub - microsoft/graphrag: A modular graph-based Retrieval-Augmented Generation (RAG) system
自动化模板生成部分的代码位于graphrag>prompt_tune中,接下来将详细分析其代码。
__main__.py
导入模块
import argparse
import asyncio
from enum import Enum
from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE
from .cli import prompt_tune
argparse
: 用于解析命令行参数。asyncio
: 用于运行异步任务。Enum
: 用于定义枚举类型。- 从
graphrag.prompt_tune.generator
导入MAX_TOKEN_COUNT
。 - 从
graphrag.prompt_tune.loader
导入MIN_CHUNK_SIZE
。 - 从当前包的
cli
模块导入prompt_tune
函数。
枚举类型定义
class DocSelectionType(Enum):
"""The type of document selection to use."""
ALL = "all"
RANDOM = "random"
TOP = "top"
AUTO = "auto"
def __str__(self):
"""Return the string representation of the enum value."""
return self.value
- 定义一个枚举类型
DocSelectionType
,表示文档选择的方法。 - 枚举值包括
ALL
、RANDOM
、TOP
和AUTO
。 - 重写
__str__
方法,返回枚举值的字符串表示
添加命令行参数
parser.add_argument(
"--root",
help="The data project root. Including the config yml, json or .env",
required=False,
type=str,
default=".",
)
parser.add_argument(
"--domain",
help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.",
required=False,
default="",
type=str,
)
parser.add_argument(
"--method",
help="The method to select documents, one of: all, random, top or auto",
required=False,
type=DocSelectionType,
choices=list(DocSelectionType),
default=DocSelectionType.RANDOM,
)
parser.add_argument(
"--n_subset_max",
help="The number of text chunks to embed when using auto selection method",
required=False,
type=int,
default=300,
)
parser.add_argument(
"--k",
help="The maximum number of documents to select from each centroid when using auto selection method",
required=False,
type=int,
default=15,
)
parser.add_argument(
"--limit",
help="The limit of files to load when doing random or top selection",
type=int,
required=False,
default=15,
)
parser.add_argument(
"--max-tokens",
help="Max token count for prompt generation",
type=int,
required=False,
default=MAX_TOKEN_COUNT,
)
parser.add_argument(
"--min-examples-required",
help="The minimum number of examples required in entity extraction prompt",
type=int,
required=False,
default=2,
)
parser.add_argument(
"--chunk-size",
help="Max token count for prompt generation",
type=int,
required=False,
default=MIN_CHUNK_SIZE,
)
parser.add_argument(
"--language",
help="Primary language used for inputs and outputs on GraphRAG",
type=str,
required=False,
default="",
)
parser.add_argument(
"--no-entity-types",
help="Use untyped entity extraction generation",
action="store_true",
required=False,
default=False,
)
parser.add_argument(
"--output",
help="Folder to save the generated prompts to",
type=str,
required=False,
default="prompts",
)
- 添加多个命令行参数,每个参数都有帮助信息、类型、是否必需和默认值。各参数的具体作用可以参考官方文档
解析命令行参数并运行异步任务
args = parser.parse_args()
loop = asyncio.get_event_loop()
loop.run_until_complete(
prompt_tune(
args.root,
args.domain,
str(args.method),
args.limit,
args.max_tokens,
args.chunk_size,
args.language,
args.no_entity_types,
args.output,
args.n_subset_max,
args.k,
args.min_examples_required,
)
)
- 解析命令行参数,并将结果存储在
args
对象中。 - 获取当前的事件循环
loop
。 - 使用
run_until_complete
方法运行prompt_tune
异步任务,传递解析的命令行参数。run_until_complete
是 Python 的asyncio
库中的一个方法,用于运行一个异步任务直到完成。
cli.py
这个代码定义了一个命令行接口,用于微调模型并生成提示模板。它主要包含三个异步函数:prompt_tune
、prompt_tune_with_config
和 generate_indexing_prompts
。
导入模块
from pathlib import Path
from datashaper import NoopVerbCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from graphrag.index.progress.types import ProgressReporter
from graphrag.llm.types.llm_types import CompletionLLM
from graphrag.prompt_tune.generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
create_entity_summarization_prompt,
detect_language,
generate_community_report_rating,
generate_community_reporter_role,
generate_domain,
generate_entity_relationship_examples,
generate_entity_types,
generate_persona,
)
from graphrag.prompt_tune.loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
read_config_parameters,
)
- 导入
Path
用于处理文件路径。 - 导入
NoopVerbCallbacks
,一个空操作回调类。 - 导入
GraphRagConfig
,用于读取配置。 - 导入
load_llm
,用于加载语言模型。 - 导入
PrintProgressReporter
和ProgressReporter
,用于报告进度。 - 导入
CompletionLLM
,表示完成任务的语言模型。 - 导入多个函数和常量,用于生成提示和加载文档。
异步函数 prompt_tune
async def prompt_tune(
root: str,
domain: str,
select: str = "random",
limit: int = 15,
max_tokens: int = MAX_TOKEN_COUNT,
chunk_size: int = MIN_CHUNK_SIZE,
language: str | None = None,
skip_entity_types: bool = False,
output: str = "prompts",
n_subset_max: int = 300,
k: int = 15,
min_examples_required: int = 2,
):
"""Prompt tune the model.
Parameters
----------
- root: The root directory.
- domain: The domain to map the input documents to.
- select: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- chunk_size: The chunk token size to use.
- skip_entity_types: Skip generating entity types.
- output: The output folder to store the prompts.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
"""
reporter = PrintProgressReporter("")
config = read_config_parameters(root, reporter)
await prompt_tune_with_config(
root,
config,
domain,
select,
limit,
max_tokens,
chunk_size,
language,
skip_entity_types,
output,
reporter,
n_subset_max,
k,
min_examples_required,
)
- 这个函数用于微调模型,生成提示模板。
- 它接受多个参数,如根目录、领域、选择方法、限制、最大令牌数、块大小、语言、是否跳过实体类型、输出目录、嵌入文本块的最大数量和选择的文档数量。
- 创建一个
PrintProgressReporter
对象reporter
,用于报告进度。 - 调用
read_config_parameters
读取配置。 - 调用
prompt_tune_with_config
函数,传递所有参数和配置。
异步函数 prompt_tune_with_config
async def prompt_tune_with_config(
root: str,
config: GraphRagConfig,
domain: str,
select: str = "random",
limit: int = 15,
max_tokens: int = MAX_TOKEN_COUNT,
chunk_size: int = MIN_CHUNK_SIZE,
language: str | None = None,
skip_entity_types: bool = False,
output: str = "prompts",
reporter: ProgressReporter | None = None,
n_subset_max: int = 300,
k: int = 15,
min_examples_required: int = 2,
):
"""Prompt tune the model with a configuration.
Parameters
----------
- root: The root directory.
- config: The GraphRag configuration.
- domain: The domain to map the input documents to.
- select: The chunk selection method.
- limit: The limit of chunks to load.
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
- chunk_size: The chunk token size to use for input text units.
- skip_entity_types: Skip generating entity types.
- output: The output folder to store the prompts.
- reporter: The progress reporter.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.
Returns
-------
- None
"""
if not reporter:
reporter = PrintProgressReporter("")
output_path = Path(config.root_dir) / output
doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=select,
reporter=reporter,
chunk_size=chunk_size,
n_subset_max=n_subset_max,
k=k,
)
# Create LLM from config
llm = load_llm(
"prompt_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
config.llm.model_dump(),
)
await generate_indexing_prompts(
llm,
config,
doc_list,
output_path,
reporter,
domain,
language,
max_tokens,
skip_entity_types,
min_examples_required,
)
- 这个函数与
prompt_tune
类似,但它接受一个配置对象config
。 - 如果没有提供
reporter
,则创建一个新的PrintProgressReporter
对象。 - 计算输出路径
output_path
。 - 调用
load_docs_in_chunks
异步函数加载文档块。 - 调用
load_llm
函数创建语言模型llm
。 - 调用
generate_indexing_prompts
异步函数生成索引提示。
异步函数 generate_indexing_prompts
函数定义和参数
async def generate_indexing_prompts(
llm: CompletionLLM,
config: GraphRagConfig,
doc_list: list[str],
output_path: Path,
reporter: ProgressReporter,
domain: str | None = None,
language: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
skip_entity_types: bool = False,
min_examples_required: int = 2,
):
"""Generate indexing prompts.
Parameters
----------
- llm: The LLM model to use.
- config: The GraphRag configuration.
- doc_list: The list of documents to use.
- output_path: The path to store the prompts.
- reporter: The progress reporter.
- domain: The domain to map the input documents to.
- max_tokens: The maximum number of tokens to use on entity extraction prompts
- skip_entity_types: Skip generating entity types.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
llm
: 用于生成提示的语言模型。config
: GraphRag 配置对象。doc_list
: 要使用的文档列表。output_path
: 存储提示的路径。reporter
: 进度报告器。domain
: 输入文档所属的领域。language
: 文档的语言。max_tokens
: 实体提取提示的最大令牌数。skip_entity_types
: 是否跳过生成实体类型。min_examples_required
: 实体提取提示所需的最小示例数。
生成领域
if not domain:
reporter.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")
- 如果没有提供领域,调用
generate_domain
生成领域。 - 使用
reporter
记录生成的领域。
检测语言
if not language:
reporter.info("Detecting language...")
language = await detect_language(llm, doc_list)
reporter.info(f"Detected language: {language}")
- 如果没有提供语言,调用
detect_language
检测语言。 - 使用
reporter
记录检测到的语言。
生成角色
reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)
reporter.info(f"Generated persona: {persona}")
- 调用
generate_persona
生成角色。 - 使用
reporter
记录生成的角色。
生成社区报告排名描述
reporter.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(
f"Generated community report ranking description: {community_report_ranking}"
)
- 调用
generate_community_report_rating
生成社区报告排名描述。 - 使用
reporter
记录生成的社区报告排名描述。
生成实体类型
entity_types = None
if not skip_entity_types:
reporter.info("Generating entity types")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
)
reporter.info(f"Generated entity types: {entity_types}")
- 如果没有跳过实体类型,调用
generate_entity_types
生成实体类型。 - 使用
reporter
记录生成的实体类型。
生成实体关系示例
reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
entity_types=entity_types,
docs=doc_list,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
)
reporter.info("Done generating entity relationship examples")
- 调用
generate_entity_relationship_examples
生成实体关系示例。 - 使用
reporter
记录生成的实体关系示例。
生成实体提取提示
reporter.info("Generating entity extraction prompt...")
create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
output_path=output_path,
encoding_model=config.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)
reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
- 调用
create_entity_extraction_prompt
生成实体提取提示。 - 使用
reporter
记录生成的实体提取提示,并存储在指定的输出路径。
生成实体总结提示
reporter.info("Generating entity summarization prompt...")
create_entity_summarization_prompt(
persona=persona,
language=language,
output_path=output_path,
)
reporter.info(
f"Generated entity summarization prompt, stored in folder {output_path}"
)
- 调用
create_entity_summarization_prompt
生成实体总结提示。 - 使用
reporter
记录生成的实体总结提示,并存储在指定的输出路径。
生成社区报告者角色
reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(f"Generated community reporter role: {community_reporter_role}")
- 调用
generate_community_reporter_role
生成社区报告者角色。 - 使用
reporter
记录生成的社区报告者角色。
生成社区总结提示
reporter.info("Generating community summarization prompt...")
create_community_summarization_prompt(
persona=persona,
role=community_reporter_role,
report_rating_description=community_report_ranking,
language=language,
output_path=output_path,
)
reporter.info(
f"Generated community summarization prompt, stored in folder {output_path}"
)
- 调用
create_community_summarization_prompt
生成社区总结提示。 - 使用
reporter
记录生成的社区总结提示,并存储在指定的输出路径。
总结
这个函数 generate_indexing_prompts
通过调用一系列生成函数,生成各种提示模板,包括领域、语言、角色、社区报告排名描述、实体类型、实体关系示例、实体提取提示、实体总结提示和社区总结提示。每一步都使用 reporter
记录进度,并将生成的提示存储在指定的输出路径。通过这种方式,用户可以自动生成用于索引的提示模板。
生成函数
生成函数在prompt_tune中的generator模块中,接下来以生成角色为例来看一下他是怎么生成的,在persona.py中。里面的核心函数 generate_persona
通过调用语言模型生成一个用于 GraphRAG 提示的角色。它首先格式化任务和生成提示,然后异步调用语言模型生成角色,最后返回生成的角色。通过这种方式,用户可以为特定领域和任务生成自定义的角色,用于微调 GraphRAG 提示。
导入模块
from graphrag.llm.types.llm_types import CompletionLLM
from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK
from graphrag.prompt_tune.prompt import GENERATE_PERSONA_PROMPT
- 从
graphrag.llm.types.llm_types
导入CompletionLLM
,表示完成任务的语言模型。 - 从
graphrag.prompt_tune.generator.defaults
导入DEFAULT_TASK
,这是默认的任务描述。 - 从
graphrag.prompt_tune.prompt
导入GENERATE_PERSONA_PROMPT
,这是生成角色的提示模板。
异步函数 generate_persona
async def generate_persona(
llm: CompletionLLM, domain: str, task: str = DEFAULT_TASK
) -> str:
"""Generate an LLM persona to use for GraphRAG prompts.
Parameters
----------
- llm (CompletionLLM): The LLM to use for generation
- domain (str): The domain to generate a persona for
- task (str): The task to generate a persona for. Default is DEFAULT_TASK
"""
- 这个函数用于生成一个用于 GraphRAG 提示的角色。
- 它接受三个参数:
llm
: 用于生成角色的语言模型。domain
: 要为其生成角色的领域。task
: 要为其生成角色的任务,默认为DEFAULT_TASK
。
格式化任务和生成提示
formatted_task = task.format(domain=domain)
persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task)
- 使用
domain
格式化task
,生成formatted_task
。 - 使用
formatted_task
格式化GENERATE_PERSONA_PROMPT
,生成persona_prompt
。
调用语言模型生成角色
response = await llm(persona_prompt)
- 异步调用语言模型
llm
,传递生成的提示persona_prompt
。 response
保存语言模型的响应。
返回生成的角色
return str(response.output)
- 返回语言模型响应的输出,转换为字符串。
DEFAULT_TASK与GENERATE_PERSONA_PROMPT
让我们来看一下默认的任务描述与生成角色的提示模板
DEFAULT_TASK = """
Identify the relations and structure of the community of interest, specifically within the {domain} domain.
"""
GENERATE_PERSONA_PROMPT = """
You are an intelligent assistant that helps a human to analyze the information in a text document.
Given a specific type of task and sample text, help the user by generating a 3 to 4 sentence description of an expert who could help solve the problem.
Use a format similar to the following:
You are an expert {{role}}. You are skilled at {{relevant skills}}. You are adept at helping people with {{specific task}}.
task: {sample_task}
persona description:"""
针对不同的生成任务有不同的提示词与输入信息。