微软GraphRAG的自动化生成提示模板部分代码详解

微软的开源GraphRAG 项目,github地址为:GitHub - microsoft/graphrag: A modular graph-based Retrieval-Augmented Generation (RAG) system

自动化模板生成部分的代码位于graphrag>prompt_tune中,接下来将详细分析其代码。

__main__.py

导入模块

import argparse
import asyncio
from enum import Enum

from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE

from .cli import prompt_tune
  • argparse: 用于解析命令行参数。
  • asyncio: 用于运行异步任务。
  • Enum: 用于定义枚举类型。
  • graphrag.prompt_tune.generator 导入 MAX_TOKEN_COUNT
  • graphrag.prompt_tune.loader 导入 MIN_CHUNK_SIZE
  • 从当前包的 cli 模块导入 prompt_tune 函数。

枚举类型定义

class DocSelectionType(Enum):
    """The type of document selection to use."""

    ALL = "all"
    RANDOM = "random"
    TOP = "top"
    AUTO = "auto"

    def __str__(self):
        """Return the string representation of the enum value."""
        return self.value
  • 定义一个枚举类型 DocSelectionType,表示文档选择的方法。
  • 枚举值包括 ALLRANDOMTOPAUTO
  • 重写 __str__ 方法,返回枚举值的字符串表示

添加命令行参数

parser.add_argument(
    "--root",
    help="The data project root. Including the config yml, json or .env",
    required=False,
    type=str,
    default=".",
)

parser.add_argument(
    "--domain",
    help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.",
    required=False,
    default="",
    type=str,
)

parser.add_argument(
    "--method",
    help="The method to select documents, one of: all, random, top or auto",
    required=False,
    type=DocSelectionType,
    choices=list(DocSelectionType),
    default=DocSelectionType.RANDOM,
)

parser.add_argument(
    "--n_subset_max",
    help="The number of text chunks to embed when using auto selection method",
    required=False,
    type=int,
    default=300,
)

parser.add_argument(
    "--k",
    help="The maximum number of documents to select from each centroid when using auto selection method",
    required=False,
    type=int,
    default=15,
)

parser.add_argument(
    "--limit",
    help="The limit of files to load when doing random or top selection",
    type=int,
    required=False,
    default=15,
)

parser.add_argument(
    "--max-tokens",
    help="Max token count for prompt generation",
    type=int,
    required=False,
    default=MAX_TOKEN_COUNT,
)

parser.add_argument(
    "--min-examples-required",
    help="The minimum number of examples required in entity extraction prompt",
    type=int,
    required=False,
    default=2,
)

parser.add_argument(
    "--chunk-size",
    help="Max token count for prompt generation",
    type=int,
    required=False,
    default=MIN_CHUNK_SIZE,
)

parser.add_argument(
    "--language",
    help="Primary language used for inputs and outputs on GraphRAG",
    type=str,
    required=False,
    default="",
)

parser.add_argument(
    "--no-entity-types",
    help="Use untyped entity extraction generation",
    action="store_true",
    required=False,
    default=False,
)

parser.add_argument(
    "--output",
    help="Folder to save the generated prompts to",
    type=str,
    required=False,
    default="prompts",
)
  • 添加多个命令行参数,每个参数都有帮助信息、类型、是否必需和默认值。各参数的具体作用可以参考官方文档

解析命令行参数并运行异步任务

args = parser.parse_args()
loop = asyncio.get_event_loop()

loop.run_until_complete(
    prompt_tune(
        args.root,
        args.domain,
        str(args.method),
        args.limit,
        args.max_tokens,
        args.chunk_size,
        args.language,
        args.no_entity_types,
        args.output,
        args.n_subset_max,
        args.k,
        args.min_examples_required,
    )
)
  • 解析命令行参数,并将结果存储在 args 对象中。
  • 获取当前的事件循环 loop
  • 使用 run_until_complete 方法运行 prompt_tune 异步任务,传递解析的命令行参数。run_until_complete 是 Python 的 asyncio 库中的一个方法,用于运行一个异步任务直到完成。

cli.py

这个代码定义了一个命令行接口,用于微调模型并生成提示模板。它主要包含三个异步函数:prompt_tuneprompt_tune_with_configgenerate_indexing_prompts

导入模块

from pathlib import Path

from datashaper import NoopVerbCallbacks

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from graphrag.index.progress.types import ProgressReporter
from graphrag.llm.types.llm_types import CompletionLLM
from graphrag.prompt_tune.generator import (
    MAX_TOKEN_COUNT,
    create_community_summarization_prompt,
    create_entity_extraction_prompt,
    create_entity_summarization_prompt,
    detect_language,
    generate_community_report_rating,
    generate_community_reporter_role,
    generate_domain,
    generate_entity_relationship_examples,
    generate_entity_types,
    generate_persona,
)
from graphrag.prompt_tune.loader import (
    MIN_CHUNK_SIZE,
    load_docs_in_chunks,
    read_config_parameters,
)
  • 导入 Path 用于处理文件路径。
  • 导入 NoopVerbCallbacks,一个空操作回调类。
  • 导入 GraphRagConfig,用于读取配置。
  • 导入 load_llm,用于加载语言模型。
  • 导入 PrintProgressReporterProgressReporter,用于报告进度。
  • 导入 CompletionLLM,表示完成任务的语言模型。
  • 导入多个函数和常量,用于生成提示和加载文档。

异步函数 prompt_tune

async def prompt_tune(
    root: str,
    domain: str,
    select: str = "random",
    limit: int = 15,
    max_tokens: int = MAX_TOKEN_COUNT,
    chunk_size: int = MIN_CHUNK_SIZE,
    language: str | None = None,
    skip_entity_types: bool = False,
    output: str = "prompts",
    n_subset_max: int = 300,
    k: int = 15,
    min_examples_required: int = 2,
):
    """Prompt tune the model.

    Parameters
    ----------
    - root: The root directory.
    - domain: The domain to map the input documents to.
    - select: The chunk selection method.
    - limit: The limit of chunks to load.
    - max_tokens: The maximum number of tokens to use on entity extraction prompts.
    - chunk_size: The chunk token size to use.
    - skip_entity_types: Skip generating entity types.
    - output: The output folder to store the prompts.
    - n_subset_max: The number of text chunks to embed when using auto selection method.
    - k: The number of documents to select when using auto selection method.
    """
    reporter = PrintProgressReporter("")
    config = read_config_parameters(root, reporter)

    await prompt_tune_with_config(
        root,
        config,
        domain,
        select,
        limit,
        max_tokens,
        chunk_size,
        language,
        skip_entity_types,
        output,
        reporter,
        n_subset_max,
        k,
        min_examples_required,
    )
  • 这个函数用于微调模型,生成提示模板。
  • 它接受多个参数,如根目录、领域、选择方法、限制、最大令牌数、块大小、语言、是否跳过实体类型、输出目录、嵌入文本块的最大数量和选择的文档数量。
  • 创建一个 PrintProgressReporter 对象 reporter,用于报告进度。
  • 调用 read_config_parameters 读取配置。
  • 调用 prompt_tune_with_config 函数,传递所有参数和配置。

异步函数 prompt_tune_with_config

async def prompt_tune_with_config(
    root: str,
    config: GraphRagConfig,
    domain: str,
    select: str = "random",
    limit: int = 15,
    max_tokens: int = MAX_TOKEN_COUNT,
    chunk_size: int = MIN_CHUNK_SIZE,
    language: str | None = None,
    skip_entity_types: bool = False,
    output: str = "prompts",
    reporter: ProgressReporter | None = None,
    n_subset_max: int = 300,
    k: int = 15,
    min_examples_required: int = 2,
):
    """Prompt tune the model with a configuration.

    Parameters
    ----------
    - root: The root directory.
    - config: The GraphRag configuration.
    - domain: The domain to map the input documents to.
    - select: The chunk selection method.
    - limit: The limit of chunks to load.
    - max_tokens: The maximum number of tokens to use on entity extraction prompts.
    - chunk_size: The chunk token size to use for input text units.
    - skip_entity_types: Skip generating entity types.
    - output: The output folder to store the prompts.
    - reporter: The progress reporter.
    - n_subset_max: The number of text chunks to embed when using auto selection method.
    - k: The number of documents to select when using auto selection method.

    Returns
    -------
    - None
    """
    if not reporter:
        reporter = PrintProgressReporter("")

    output_path = Path(config.root_dir) / output

    doc_list = await load_docs_in_chunks(
        root=root,
        config=config,
        limit=limit,
        select_method=select,
        reporter=reporter,
        chunk_size=chunk_size,
        n_subset_max=n_subset_max,
        k=k,
    )

    # Create LLM from config
    llm = load_llm(
        "prompt_tuning",
        config.llm.type,
        NoopVerbCallbacks(),
        None,
        config.llm.model_dump(),
    )

    await generate_indexing_prompts(
        llm,
        config,
        doc_list,
        output_path,
        reporter,
        domain,
        language,
        max_tokens,
        skip_entity_types,
        min_examples_required,
    )
  • 这个函数与 prompt_tune 类似,但它接受一个配置对象 config
  • 如果没有提供 reporter,则创建一个新的 PrintProgressReporter 对象。
  • 计算输出路径 output_path
  • 调用 load_docs_in_chunks 异步函数加载文档块。
  • 调用 load_llm 函数创建语言模型 llm
  • 调用 generate_indexing_prompts 异步函数生成索引提示。

异步函数 generate_indexing_prompts

函数定义和参数
async def generate_indexing_prompts(
    llm: CompletionLLM,
    config: GraphRagConfig,
    doc_list: list[str],
    output_path: Path,
    reporter: ProgressReporter,
    domain: str | None = None,
    language: str | None = None,
    max_tokens: int = MAX_TOKEN_COUNT,
    skip_entity_types: bool = False,
    min_examples_required: int = 2,
):
    """Generate indexing prompts.

    Parameters
    ----------
    - llm: The LLM model to use.
    - config: The GraphRag configuration.
    - doc_list: The list of documents to use.
    - output_path: The path to store the prompts.
    - reporter: The progress reporter.
    - domain: The domain to map the input documents to.
    - max_tokens: The maximum number of tokens to use on entity extraction prompts
    - skip_entity_types: Skip generating entity types.
    - min_examples_required: The minimum number of examples required for entity extraction prompts.
    """
  • llm: 用于生成提示的语言模型。
  • config: GraphRag 配置对象。
  • doc_list: 要使用的文档列表。
  • output_path: 存储提示的路径。
  • reporter: 进度报告器。
  • domain: 输入文档所属的领域。
  • language: 文档的语言。
  • max_tokens: 实体提取提示的最大令牌数。
  • skip_entity_types: 是否跳过生成实体类型。
  • min_examples_required: 实体提取提示所需的最小示例数。
生成领域
if not domain:
    reporter.info("Generating domain...")
    domain = await generate_domain(llm, doc_list)
    reporter.info(f"Generated domain: {domain}")
  • 如果没有提供领域,调用 generate_domain 生成领域。
  • 使用 reporter 记录生成的领域。
检测语言
if not language:
    reporter.info("Detecting language...")
    language = await detect_language(llm, doc_list)
    reporter.info(f"Detected language: {language}")
  • 如果没有提供语言,调用 detect_language 检测语言。
  • 使用 reporter 记录检测到的语言。
生成角色
reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)
reporter.info(f"Generated persona: {persona}")
  • 调用 generate_persona 生成角色。
  • 使用 reporter 记录生成的角色。
生成社区报告排名描述
reporter.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
    llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(
    f"Generated community report ranking description: {community_report_ranking}"
)
  • 调用 generate_community_report_rating 生成社区报告排名描述。
  • 使用 reporter 记录生成的社区报告排名描述。
生成实体类型
entity_types = None
if not skip_entity_types:
    reporter.info("Generating entity types")
    entity_types = await generate_entity_types(
        llm,
        domain=domain,
        persona=persona,
        docs=doc_list,
        json_mode=config.llm.model_supports_json or False,
    )
    reporter.info(f"Generated entity types: {entity_types}")
  • 如果没有跳过实体类型,调用 generate_entity_types 生成实体类型。
  • 使用 reporter 记录生成的实体类型。
生成实体关系示例
reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
    llm,
    persona=persona,
    entity_types=entity_types,
    docs=doc_list,
    language=language,
    json_mode=False,  # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
)
reporter.info("Done generating entity relationship examples")
  • 调用 generate_entity_relationship_examples 生成实体关系示例。
  • 使用 reporter 记录生成的实体关系示例。
生成实体提取提示
reporter.info("Generating entity extraction prompt...")
create_entity_extraction_prompt(
    entity_types=entity_types,
    docs=doc_list,
    examples=examples,
    language=language,
    json_mode=False,  # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
    output_path=output_path,
    encoding_model=config.encoding_model,
    max_token_count=max_tokens,
    min_examples_required=min_examples_required,
)
reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
  • 调用 create_entity_extraction_prompt 生成实体提取提示。
  • 使用 reporter 记录生成的实体提取提示,并存储在指定的输出路径。
生成实体总结提示
reporter.info("Generating entity summarization prompt...")
create_entity_summarization_prompt(
    persona=persona,
    language=language,
    output_path=output_path,
)
reporter.info(
    f"Generated entity summarization prompt, stored in folder {output_path}"
)
  • 调用 create_entity_summarization_prompt 生成实体总结提示。
  • 使用 reporter 记录生成的实体总结提示,并存储在指定的输出路径。
生成社区报告者角色
reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
    llm, domain=domain, persona=persona, docs=doc_list
)
reporter.info(f"Generated community reporter role: {community_reporter_role}")
  • 调用 generate_community_reporter_role 生成社区报告者角色。
  • 使用 reporter 记录生成的社区报告者角色。
生成社区总结提示
reporter.info("Generating community summarization prompt...")
create_community_summarization_prompt(
    persona=persona,
    role=community_reporter_role,
    report_rating_description=community_report_ranking,
    language=language,
    output_path=output_path,
)
reporter.info(
    f"Generated community summarization prompt, stored in folder {output_path}"
)
  • 调用 create_community_summarization_prompt 生成社区总结提示。
  • 使用 reporter 记录生成的社区总结提示,并存储在指定的输出路径。
总结

这个函数 generate_indexing_prompts 通过调用一系列生成函数,生成各种提示模板,包括领域、语言、角色、社区报告排名描述、实体类型、实体关系示例、实体提取提示、实体总结提示和社区总结提示。每一步都使用 reporter 记录进度,并将生成的提示存储在指定的输出路径。通过这种方式,用户可以自动生成用于索引的提示模板。

生成函数

生成函数在prompt_tune中的generator模块中,接下来以生成角色为例来看一下他是怎么生成的,在persona.py中。里面的核心函数 generate_persona 通过调用语言模型生成一个用于 GraphRAG 提示的角色。它首先格式化任务和生成提示,然后异步调用语言模型生成角色,最后返回生成的角色。通过这种方式,用户可以为特定领域和任务生成自定义的角色,用于微调 GraphRAG 提示。

导入模块

from graphrag.llm.types.llm_types import CompletionLLM
from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK
from graphrag.prompt_tune.prompt import GENERATE_PERSONA_PROMPT
  • graphrag.llm.types.llm_types 导入 CompletionLLM,表示完成任务的语言模型。
  • graphrag.prompt_tune.generator.defaults 导入 DEFAULT_TASK,这是默认的任务描述。
  • graphrag.prompt_tune.prompt 导入 GENERATE_PERSONA_PROMPT,这是生成角色的提示模板。

异步函数 generate_persona

async def generate_persona(
    llm: CompletionLLM, domain: str, task: str = DEFAULT_TASK
) -> str:
    """Generate an LLM persona to use for GraphRAG prompts.

    Parameters
    ----------
    - llm (CompletionLLM): The LLM to use for generation
    - domain (str): The domain to generate a persona for
    - task (str): The task to generate a persona for. Default is DEFAULT_TASK
    """
  • 这个函数用于生成一个用于 GraphRAG 提示的角色。
  • 它接受三个参数:
    • llm: 用于生成角色的语言模型。
    • domain: 要为其生成角色的领域。
    • task: 要为其生成角色的任务,默认为 DEFAULT_TASK

格式化任务和生成提示

formatted_task = task.format(domain=domain)
persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task)
  • 使用 domain 格式化 task,生成 formatted_task
  • 使用 formatted_task 格式化 GENERATE_PERSONA_PROMPT,生成 persona_prompt

调用语言模型生成角色

response = await llm(persona_prompt)
  • 异步调用语言模型 llm,传递生成的提示 persona_prompt
  • response 保存语言模型的响应。

返回生成的角色

return str(response.output)
  • 返回语言模型响应的输出,转换为字符串。

DEFAULT_TASK与GENERATE_PERSONA_PROMPT

让我们来看一下默认的任务描述与生成角色的提示模板

DEFAULT_TASK = """
Identify the relations and structure of the community of interest, specifically within the {domain} domain.
"""
GENERATE_PERSONA_PROMPT = """
You are an intelligent assistant that helps a human to analyze the information in a text document.
Given a specific type of task and sample text, help the user by generating a 3 to 4 sentence description of an expert who could help solve the problem.
Use a format similar to the following:
You are an expert {{role}}. You are skilled at {{relevant skills}}. You are adept at helping people with {{specific task}}.

task: {sample_task}
persona description:"""

针对不同的生成任务有不同的提示词与输入信息。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值