CRAG讲解

Corrective RAG (CRAG)

Corrective RAG (CRAG)

概述

Corrective RAG (CRAG) 是一种增强型的 RAG(检索增强生成)策略,结合了自我反思和自我评分机制,用于提高检索文档和生成内容的质量。CRAG 通过多步骤的评估和纠正机制,旨在进一步提升回答的相关性和准确性,减少错误信息(如幻觉)。

CRAG 的主要步骤

在论文中,CRAG 采取了以下几个关键步骤:

  1. 评估文档的相关性
    • 如果至少有一个文档超过了相关性阈值,则继续生成回答。
    • 如果所有文档的相关性都低于阈值,或者评分器不确定,则框架会寻求额外的数据源,通过网络搜索来补充检索结果。
  2. 知识精炼(Knowledge Refinement)
    • 在生成回答之前,进行知识精炼。
    • 将文档划分为“知识条块”(knowledge strips)。
    • 对每个条块进行评分,并过滤掉不相关的条块。
  3. 网络搜索补充
    • 使用 Tavily Search 进行网络搜索,以补充检索结果。
    • 使用查询重写(query re-writing)优化查询,以提高网络搜索的效果。

在本实现中,我们将首先跳过知识精炼阶段。如果发现任何文档不相关,则选择使用网络搜索补充检索结果。

系统架构图

系统的图形化表示如下所示:

在这里插入图片描述


设置环境(Setup)

首先,下载所需的包并设置必要的API密钥。

1. 安装必要的包

在Jupyter Notebook或终端中运行以下命令安装所需的包:

pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python

2. 设置API密钥

接下来,设置OpenAI和Tavily的API密钥。以下代码将提示您输入API密钥并将其存储在环境变量中:

import getpass
import os

def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")

_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

3. 设置LangSmith用于LangGraph开发

创建索引(Create Index)

1. 构建索引

我们首先需要构建一个文档索引,以便后续的检索和生成过程。

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

# 设置嵌入模型
embd = OpenAIEmbeddings()

# 要索引的文档URL
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# 使用WebBaseLoader加载文档
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# 使用RecursiveCharacterTextSplitter拆分文档
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# 将拆分后的文档添加到向量存储中
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

解释:

  • WebBaseLoader:从指定的URL递归加载网页内容。
  • RecursiveCharacterTextSplitter:将长文档拆分成较小的块,以便LLM更高效地处理。
  • Chroma:使用向量存储(vectorstore)管理文档的嵌入向量,并提供高效的相似度检索。
  • retriever:将向量存储作为检索器,供LLM调用以获取相关文档。

LLMs 配置

使用Pydantic与LangChain

此部分使用Pydantic v2的BaseModel,需要langchain-core >= 0.3。使用langchain-core < 0.3将导致因混合使用Pydantic v1和v2而出错。

1. 检索评分器(Retrieval Grader)

检索评分器用于评估检索到的文档是否与用户问题相关。

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

# 数据模型
class GradeDocuments(BaseModel):
    """评估检索到的文档相关性的二进制评分。"""

    binary_score: str = Field(
        description="文档是否与问题相关,'yes'或'no'"
    )

# 初始化LLM并绑定结构化输出
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# 定义系统消息模板
system = """你是一个评分员,负责评估检索到的文档是否与用户的问题相关。
这不需要是严格的测试,目标是过滤掉错误的检索结果。
如果文档包含与用户问题相关的关键词或语义含义,请将其评分为相关。
请给出二元评分“yes”或“no”,以指示文档是否与问题相关。"""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "检索到的文档:\n\n {document} \n\n 用户问题:{question}"),
    ]
)

# 构建评分链
retrieval_grader = grade_prompt | structured_llm_grader

# 示例调用
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))

# 输出示例
binary_score='yes'

解释:

  • GradeDocuments:定义了评分器的输出结构,包括binary_score字段,值为"yes""no"
  • grade_prompt:定义了用于评估文档相关性的提示模板。
  • retrieval_grader:结合了提示模板和LLM的评分链。
  • 示例调用:评估特定文档是否与用户问题相关。

2. 生成回答节点(Generate)

生成回答节点基于检索到的文档生成最终回答。

from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# 获取提示模板
prompt = hub.pull("rlm/rag-prompt")

# 初始化LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# 后处理函数:格式化文档
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# 构建RAG链
rag_chain = prompt | llm | StrOutputParser()

# 运行RAG链
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)

# 输出示例
# The design of generative agents combines LLM with memory, planning, and reflection mechanisms to enable agents to behave conditioned on past experience. Memory stream is a long-term memory module that records a comprehensive list of agents' experience in natural language. Short-term memory is utilized for in-context learning, while long-term memory allows agents to retain and recall information over extended periods.

解释:

  • hub.pull(“rlm/rag-prompt”):从LangChain Hub拉取预定义的RAG提示模板。
  • rag_chain:结合提示模板和LLM,创建一个RAG链。
  • 后处理:将检索到的文档内容格式化为字符串,作为上下文给LLM。
  • generation:基于上下文和用户问题生成的回答。

3. 问题重写器(Question Re-writer)

问题重写器用于优化用户的问题,以提高检索效果。

# 初始化LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# 定义系统消息模板
system = """你是一个问题重写器,负责将输入的问题转换为更好的版本,以优化网络搜索的效果。
请查看输入的问题,并尝试推理其潜在的语义意图或含义。"""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "这是初始问题:\n\n {question} \n 请制定一个改进后的问题。",
        ),
    ]
)

# 构建问题重写链
question_rewriter = re_write_prompt | llm | StrOutputParser()

# 示例调用
question_rewriter.invoke({"question": question})

# 输出示例
'What is the role of memory in artificial intelligence agents?'

解释:

  • re_write_prompt:定义了用于重写问题的提示模板。
  • question_rewriter:结合了提示模板和LLM的问题重写链。
  • 示例调用:优化用户的原始问题,以提高检索效果。

4. 网络搜索工具(Web Search Tool)

网络搜索工具用于处理与近期事件相关的问题,通过网络搜索获取最新信息。

from langchain_community.tools.tavily_search import TavilySearchResults

# 初始化网络搜索工具
web_search_tool = TavilySearchResults(k=3)

解释:

  • TavilySearchResults:定义了网络搜索工具,设置返回结果的数量为3。
  • web_search_tool:网络搜索工具实例,供后续调用以获取相关信息。

构建图(Construct the Graph)

1. 定义图状态(Define Graph State)

首先,定义图的状态结构,包含问题、生成的回答、是否进行网络搜索以及相关文档列表。

from typing import List
from typing_extensions import TypedDict

class GraphState(TypedDict):
    """
    表示图的状态。

    属性:
        question: 用户问题
        generation: LLM生成的回答
        web_search: 是否进行网络搜索
        documents: 文档列表
    """
    question: str
    generation: str
    web_search: str
    documents: List[str]

解释:

  • GraphState:定义了图的状态结构,包括用户问题(question)、生成的回答(generation)、是否进行网络搜索(web_search)和相关文档列表(documents)。

2. 定义图流程(Define Graph Flow)

构建图的逻辑流程,包括检索、生成、评分和重写等节点。

from langchain.schema import Document

def retrieve(state):
    """
    检索文档

    Args:
        state (dict): 当前图的状态

    Returns:
        dict: 更新状态,包含检索到的文档
    """
    print("---检索---")
    question = state["question"]

    # 调用检索器
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """
    生成回答

    Args:
        state (dict): 当前图的状态

    Returns:
        dict: 更新状态,包含生成的回答
    """
    print("---生成回答---")
    question = state["question"]
    documents = state["documents"]

    # RAG生成
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """
    确定检索到的文档是否与问题相关

    Args:
        state (dict): 当前图的状态
		retrieval_grader: 检索评分器对象
    Returns:
        dict: 更新后的状态,包含过滤后的相关文档和是否进行网络搜索的标志
    """
    print("---检查文档与问题的相关性---")
    question = state["question"]
    documents = state["documents"]

    # 评分每个文档
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---评分:文档相关---")
            filtered_docs.append(d)
        else:
            print("---评分:文档不相关---")
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def transform_query(state):
    """
    转换查询,生成更好的问题

    Args:
        state (dict): 当前图的状态

    Returns:
        dict: 更新状态,包含重新表述的问题
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # 重写问题
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

def web_search(state):
    """
    基于重新表述的问题进行网络搜索

    Args:
        state (dict): 当前图的状态

    Returns:
        dict: 更新状态,包含网络搜索结果
    """
    print("---优化问题---")
    question = state["question"]
    documents = state["documents"]

    # 网络搜索
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"documents": documents, "question": question}

解释:

  • retrieve:根据用户问题调用检索器,获取相关文档。
  • generate:基于检索到的文档生成回答。
  • grade_documents:评估每个检索到的文档是否与用户问题相关,并筛选出相关文档。如果有任何文档不相关,则标记需要进行网络搜索。
  • transform_query:优化用户的问题,以提高检索效果。
  • web_search:针对不相关的文档,通过网络搜索获取补充信息,并将搜索结果添加到文档列表中。

3. 定义边(Edges)

定义节点之间的连接关系,决定流程的执行顺序。

def decide_to_generate(state):
    """
    决定是否生成回答,或重新生成问题

    Args:
        state (dict): 当前图的状态

    Returns:
        str: 决策结果,决定下一步调用的节点
    """
    print("---评估已评分的文档---")
    web_search = state["web_search"]
    filtered_documents = state["documents"]

    if web_search == "Yes":
        # 有不相关的文档,需要进行网络搜索并重新生成问题
         print("---决策:所有文档与问题不相关,优化问题---")
        return "transform_query"
    else:
        # 有相关文档,生成回答
        print("---决策:生成回答---")
        return "generate"

def grade_generation_v_documents_and_question(state):
    """
    确定生成的回答是否基于文档且回答了问题

    Args:
        state (dict): 当前图的状态

    Returns:
        str: 决策结果,决定下一步调用的节点
    """
    print("---检查幻觉---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    # 检查幻觉
    if grade == "yes":
        print("---决策:生成的回答基于文档---")
        # 检查回答是否解决了问题
        print("---评分生成的回答是否解决问题---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---决策:生成的回答解决了问题---")
            return "useful"
        else:
            print("---决策:生成的回答未解决问题---")
            return "transform_query"
    else:
        print("---决策:生成的回答未基于文档,重试---")
        return "generate"

解释:

  • decide_to_generate:根据文档评分结果,决定是否生成回答或重新转换查询。如果有不相关的文档,则需要进行网络搜索并优化问题。
  • grade_generation_v_documents_and_question:评估生成的回答是否基于检索到的文档且有效回答了用户的问题。如果回答不符合要求,则重新生成或优化问题。

4. 编译图(Compile Graph)

使用StateGraph将所有节点和边连接起来,并编译图。

from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# 定义节点
workflow.add_node("retrieve", retrieve)  # 检索节点
workflow.add_node("grade_documents", grade_documents)  # 评估文档相关性节点
workflow.add_node("generate", generate)  # 生成回答节点
workflow.add_node("transform_query", transform_query)  # 转换查询节点
workflow.add_node("web_search_node", web_search)  # 网络搜索节点

# 定义边
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# 编译图
app = workflow.compile()

解释:

  • workflow.add_node:将各个节点添加到图中。
    • “retrieve”:检索节点,负责从向量存储中获取相关文档。
    • “grade_documents”:评估文档相关性节点,筛选相关文档。
    • “generate”:生成回答节点,基于相关文档生成最终回答。
    • “transform_query”:转换查询节点,优化用户问题以提高检索效果。
    • “web_search_node”:网络搜索节点,处理需要通过网络搜索获取的补充信息。
  • workflow.add_edge:定义节点之间的直接连接。
    • START -> “retrieve”:流程从检索节点开始。
    • “retrieve” -> “grade_documents”:检索后评估文档相关性。
    • “grade_documents” -> “transform_query” 或 “generate”:根据评估结果决定下一步。
    • “transform_query” -> “web_search_node”:优化问题后进行网络搜索。
    • “web_search_node” -> “generate”:获取补充信息后生成回答。
    • “generate” -> END:生成回答后结束流程。

使用图(Use the Graph)

1. 导入必要模块

from pprint import pprint

2. 运行

定义输入并通过图进行处理。

# 示例调用1
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # 节点
        pprint(f"Node '{key}':")
        # 可选:打印每个节点的完整状态
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# 最终生成的回答
pprint(value["generation"])

输出示例:

---RETRIEVE---
Node 'retrieve':
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
Node 'grade_documents':
'\n---\n'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
Node 'generate':
'\n---\n'
('The types of agent memory include Sensory Memory, Short-Term Memory (STM) or '
 'Working Memory, and Long-Term Memory (LTM) with subtypes of Explicit / '
 'declarative memory and Implicit / procedural memory. Sensory memory retains '
 'sensory information briefly, STM stores information for cognitive tasks, and '
 'LTM stores information for a long time with different types of memories.')
Trace:
https://smith.langchain.com/public/f6b1716c-e842-4282-9112-1026b93e246b/r
# 示例调用2
inputs = {"question": "How does the AlphaCodium paper work?"}
for output in app.stream(inputs):
    for key, value in output.items():
        # 节点
        pprint(f"Node '{key}':")
        # 可选:打印每个节点的完整状态
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint("\n---\n")

# 最终生成的回答
pprint(value["generation"])

输出示例:

---RETRIEVE---
Node 'retrieve':
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---
---TRANSFORM QUERY---
Node 'transform_query':
'\n---\n'
---WEB SEARCH---
Node 'web_search_node':
'\n---\n'
---GENERATE---
Node 'generate':
'\n---\n'
Node '__end__':
'\n---\n'
('The AlphaCodium paper functions by proposing a code-oriented iterative flow '
 'that involves repeatedly running and fixing generated code against '
 'input-output tests. Its key mechanisms include generating additional data '
 'like problem reflection and test reasoning to aid the iterative process, as '
 'well as enriching the code generation process. AlphaCodium aims to improve '
 'the performance of Large Language Models on code problems by following a '
 'test-based, multi-stage approach.')

解释:

  • 示例调用1

    • 用户问题被路由到retrieve节点,从向量存储中检索相关文档。
    • grade_documents节点评估每个文档的相关性,筛选出相关文档。
    • decide_to_generate节点决定生成回答。
    • generate节点基于相关文档生成最终回答。
    • 最终回答展示。
  • 示例调用2

    • 用户问题被路由到retrieve节点,从向量存储中检索相关文档。
    • grade_documents节点评估每个文档的相关性,发现部分文档不相关,标记需要进行网络搜索。
    • decide_to_generate节点决定需要优化查询。
    • transform_query节点优化用户问题。
    • web_search_node节点通过网络搜索获取补充信息。
    • generate节点基于补充信息生成最终回答。
    • 最终回答展示。

汇总

# crag.py

import getpass
import os
from typing import List
from typing_extensions import TypedDict
from pprint import pprint

# LangChain 和 LangGraph 相关导入
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langgraph.graph import END, StateGraph, START
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document

# 环境设置函数
def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")

_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

# 定义图的状态
class GraphState(TypedDict):
    """
    表示图的状态。

    属性:
        question: 用户问题
        generation: LLM生成的回答
        web_search: 是否进行网络搜索
        documents: 文档列表
    """
    question: str
    generation: str
    web_search: str
    documents: List[str]

# 初始化检索器
def initialize_retriever():
    urls = [
        "https://lilianweng.github.io/posts/2023-06-23-agent/",
        "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
        "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
    ]

    # 加载文档
    docs = [WebBaseLoader(url).load() for url in urls]
    docs_list = [item for sublist in docs for item in sublist]

    # 分割文档
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=250, chunk_overlap=0
    )
    doc_splits = text_splitter.split_documents(docs_list)

    # 添加到向量数据库
    vectorstore = Chroma.from_documents(
        documents=doc_splits,
        collection_name="crag-chroma",
        embedding=OpenAIEmbeddings(),
    )
    retriever = vectorstore.as_retriever()
    return retriever

# 定义评分器的数据模型和函数

# 1. 检索评分器(Retrieval Grader)
class GradeDocuments(BaseModel):
    """评估检索文档相关性的二元评分。"""
    binary_score: str = Field(
        description="文档是否与问题相关,'yes' 或 'no'"
    )

def initialize_retrieval_grader():
    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
    structured_llm_grader = llm.with_structured_output(GradeDocuments)

    system = """你是一个评分员,负责评估检索到的文档是否与用户的问题相关。
这不需要是严格的测试,目标是过滤掉错误的检索结果。
如果文档包含与用户问题相关的关键词或语义含义,请将其评分为相关。
请给出二元评分“yes”或“no”,以指示文档是否与问题相关。"""
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "检索到的文档:\n\n {document} \n\n 用户问题:{question}"),
        ]
    )

    retrieval_grader = grade_prompt | structured_llm_grader
    return retrieval_grader

# 2. 幻觉评分器(Hallucination Grader)
class GradeHallucinations(BaseModel):
    """评估回答中是否存在幻觉的二元评分。"""
    binary_score: str = Field(
        description="回答是否基于事实支持,'yes' 或 'no'"
    )

def initialize_hallucination_grader():
    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
    structured_llm_grader = llm.with_structured_output(GradeHallucinations)

    system = """你是一个评分员,负责评估LLM生成的回答是否基于一组检索到的事实。
请给出二元评分“yes”或“no”。“yes”表示回答是基于这些事实支持的。"""
    hallucination_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "事实集:\n\n {documents} \n\n LLM生成的回答:{generation}"),
        ]
    )

    hallucination_grader = hallucination_prompt | structured_llm_grader
    return hallucination_grader

# 3. 回答评分器(Answer Grader)
class GradeAnswer(BaseModel):
    """评估回答是否解决问题的二元评分。"""
    binary_score: str = Field(
        description="回答是否解决了问题,'yes' 或 'no'"
    )

def initialize_answer_grader():
    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
    structured_llm_grader = llm.with_structured_output(GradeAnswer)

    system = """你是一个评分员,负责评估一个回答是否解决了用户的问题。
请给出二元评分“yes”或“no”。“yes”表示回答解决了问题。"""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "用户问题:\n\n {question} \n\n LLM生成的回答:{generation}"),
        ]
    )

    answer_grader = answer_prompt | structured_llm_grader
    return answer_grader

# 4. 问题重写器(Question Re-writer)
def initialize_question_rewriter():
    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

    system = """你是一个问题重写器,负责将输入的问题转换为更好的版本,以优化网络搜索的效果。
请查看输入的问题,并尝试推理其潜在的语义意图或含义。"""
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            (
                "human",
                "这是初始问题:\n\n {question} \n 请制定一个改进后的问题。",
            ),
        ]
    )

    question_rewriter = re_write_prompt | llm | StrOutputParser()
    return question_rewriter

# 生成回答(Generate)
def initialize_rag_chain():
    prompt = hub.pull("rlm/crag-prompt")

    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = prompt | llm | StrOutputParser()
    return rag_chain

# 定义节点函数

def retrieve(state, retriever):
    """
    检索文档

    参数:
        state (dict): 当前图的状态
        retriever: 检索器对象

    返回:
        dict: 更新后的状态,包含检索到的文档
    """
    print("---检索---")
    question = state["question"]

    # 检索
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state, rag_chain):
    """
    生成回答

    参数:
        state (dict): 当前图的状态
        rag_chain: RAG生成链对象

    返回:
        dict: 更新后的状态,包含生成的回答
    """
    print("---生成回答---")
    question = state["question"]
    documents = state["documents"]

    # RAG生成
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state, retrieval_grader):
    """
    评估检索到的文档是否相关

    参数:
        state (dict): 当前图的状态
        retrieval_grader: 检索评分器对象

    返回:
        dict: 更新后的状态,包含过滤后的相关文档和是否进行网络搜索的标志
    """
    print("---检查文档与问题的相关性---")
    question = state["question"]
    documents = state["documents"]

    # 评分每个文档
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---评分:文档相关---")
            filtered_docs.append(d)
        else:
            print("---评分:文档不相关---")
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def transform_query(state, question_rewriter):
    """
    优化问题

    参数:
        state (dict): 当前图的状态
        question_rewriter: 问题重写器对象

    返回:
        dict: 更新后的状态,包含优化后的问题
    """
    print("---优化问题---")
    question = state["question"]
    documents = state["documents"]

    # 重写问题
    better_question = question_rewriter.invoke({"question": question})
    print(f"优化后的问题:{better_question}")
    return {"documents": documents, "question": better_question}

def web_search(state, web_search_tool):
    """
    基于优化后的问题进行网络搜索

    参数:
        state (dict): 当前图的状态
        web_search_tool: 网络搜索工具对象

    返回:
        dict: 更新后的状态,包含追加的网络搜索结果
    """
    print("---网络搜索---")
    question = state["question"]
    documents = state["documents"]

    # 网络搜索
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_result_doc = Document(page_content=web_results)
    documents.append(web_result_doc)

    return {"documents": documents, "question": question}

# 边函数

def decide_to_generate(state):
    """
    决定是否生成回答或重新优化问题

    参数:
        state (dict): 当前图的状态

    返回:
        str: 下一个节点的名称
    """
    print("---评估已评分的文档---")
    web_search = state["web_search"]

    if web_search == "Yes":
        # 所有文档均不相关,重新优化问题
        print("---决策:所有文档与问题不相关,优化问题---")
        return "transform_query"
    else:
        # 有相关文档,生成回答
        print("---决策:生成回答---")
        return "generate"

def grade_generation_v_documents_and_question(state, hallucination_grader, answer_grader):
    """
    评估生成的回答是否基于文档并解决了问题

    参数:
        state (dict): 当前图的状态
        hallucination_grader: 幻觉评分器对象
        answer_grader: 回答评分器对象

    返回:
        str: 下一个节点的名称
    """
    print("---检查幻觉---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    # 检查是否存在幻觉
    if grade == "yes":
        print("---决策:生成的回答基于文档---")
        # 检查回答是否解决了问题
        print("---评分生成的回答是否解决问题---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---决策:生成的回答解决了问题---")
            return "useful"
        else:
            print("---决策:生成的回答未解决问题---")
            return "not useful"
    else:
        print("---决策:生成的回答未基于文档,重试---")
        return "not supported"

# 构建并编译图
def build_workflow(retrieve_fn, grade_documents_fn, generate_fn, transform_query_fn,
                  decide_to_generate_fn, grade_generation_fn,
                  web_search_fn,
                  retriever, rag_chain, retrieval_grader, hallucination_grader, answer_grader, question_rewriter, web_search_tool):
    workflow = StateGraph(GraphState)

    # 定义节点
    workflow.add_node("retrieve", lambda state: retrieve_fn(state, retriever))
    workflow.add_node("grade_documents", lambda state: grade_documents_fn(state, retrieval_grader))
    workflow.add_node("generate", lambda state: generate_fn(state, rag_chain))
    workflow.add_node("transform_query", lambda state: transform_query_fn(state, question_rewriter))
    workflow.add_node("web_search_node", lambda state: web_search_fn(state, web_search_tool))

    # 构建边
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate_fn,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "web_search_node")
    workflow.add_edge("web_search_node", "generate")
    workflow.add_edge("generate", END)

    # 编译图
    app = workflow.compile()
    return app

# 运行图
def run_workflow(app, inputs):
    for output in app.stream(inputs):
        for key, value in output.items():
            # 打印每个节点的状态
            pprint(f"节点 '{key}':")
            # 可选:打印每个节点的详细状态
            # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
        pprint("\n---\n")
    
    # 打印最终生成的回答
    pprint(value.get("generation", "没有生成回答"))

def main():
    # 初始化组件
    retriever = initialize_retriever()
    retrieval_grader = initialize_retrieval_grader()
    hallucination_grader = initialize_hallucination_grader()
    answer_grader = initialize_answer_grader()
    question_rewriter = initialize_question_rewriter()
    rag_chain = initialize_rag_chain()
    web_search_tool = TavilySearchResults(k=3)

    # 构建工作流
    app = build_workflow(
        retrieve_fn=retrieve,
        grade_documents_fn=grade_documents,
        generate_fn=generate,
        transform_query_fn=transform_query,
        decide_to_generate_fn=decide_to_generate,
        grade_generation_fn=grade_generation_v_documents_and_question,
        web_search_fn=web_search,
        retriever=retriever,
        rag_chain=rag_chain,
        retrieval_grader=retrieval_grader,
        hallucination_grader=hallucination_grader,
        answer_grader=answer_grader,
        question_rewriter=question_rewriter,
        web_search_tool=web_search_tool
    )

    # 运行第一个示例
    print("=== 示例 1 ===")
    inputs1 = {"question": "What are the types of agent memory?"}
    run_workflow(app, inputs1)

    # 运行第二个示例
    print("\n=== 示例 2 ===")
    inputs2 = {"question": "How does the AlphaCodium paper work?"}
    run_workflow(app, inputs2)

if __name__ == "__main__":
    main()

代码说明

  1. 环境设置
    • 使用 getpass 获取 OPENAI_API_KEYTAVILY_API_KEY 并设置为环境变量。
    • 请确保在运行脚本时输入有效的 API 密钥。
  2. 检索器初始化
    • 从指定的 URL 加载文档。
    • 使用 RecursiveCharacterTextSplitter 将文档分割成较小的片段。
    • 将文档片段存储到 Chroma 向量数据库中,以便高效检索。
  3. 评分器初始化
    • 检索评分器(Retrieval Grader):评估检索到的文档是否与问题相关。
    • 幻觉评分器(Hallucination Grader):评估生成的回答是否基于检索到的文档,避免幻觉。
    • 回答评分器(Answer Grader):评估生成的回答是否解决了用户的问题。
  4. 问题重写器初始化
    • 优化用户输入的问题,以便更好地进行网络搜索。
  5. 生成链初始化(Generate)
    • 使用 LangChain 的 hub.pull("rlm/crag-prompt") 获取预定义的 CRAG 提示语。
    • 配置 LLM 生成回答。
  6. 节点函数定义
    • retrieve:检索相关文档。
    • generate:基于检索到的文档生成回答。
    • grade_documents:评估检索到的文档是否相关。
    • transform_query:优化用户的问题。
    • web_search:基于优化后的问题进行网络搜索。
  7. 边函数定义
    • decide_to_generate:决定是生成回答还是优化问题。
    • grade_generation_v_documents_and_question:评估生成的回答是否基于文档并解决了问题。
  8. 图的构建与编译
    • 使用 LangGraph 的 StateGraph 定义工作流图。
    • 添加节点和边,定义流程控制逻辑。
    • 编译图为可执行的应用对象 app
  9. 运行图
    • 提供输入问题,运行整个工作流,生成并打印最终的回答。
    • 示例中包含两个问题进行演示。

执行示例

运行脚本后,您将看到如下类似的输出:

复制代码=== 示例 1 ===
---检索---
---检查文档与问题的相关性---
---评分:文档相关---
---评分:文档不相关---
---评分:文档相关---
---评分:文档相关---
---评估已评分的文档---
---决策:生成回答---
节点 'grade_documents':

---

---生成回答---
---检查幻觉---
---决策:生成的回答基于文档---
---评分生成的回答是否解决问题---
---决策:生成的回答解决了问题---
节点 'generate':

---
('Agents possess short-term memory, which is utilized for in-context learning, '
 'and long-term memory, allowing them to retain and recall vast amounts of '
 'information over extended periods. Some experts also classify working memory '
 'as a distinct type, although it can be considered a part of short-term '
 'memory in many cases.')

=== 示例 2 ===
---检索---
---检查文档与问题的相关性---
---评分:文档不相关---
---评分:文档不相关---
---评分:文档不相关---
---评分:文档相关---
---评估已评分的文档---
---决策:所有文档与问题不相关,优化问题---
---优化问题---
优化后的问题: How does the AlphaCodium paper work?
节点 'transform_query':

---
---网络搜索---
节点 'web_search_node':

---
---生成回答---
---检查幻觉---
---决策:生成的回答基于文档---
---评分生成的回答是否解决问题---
---决策:生成的回答解决了问题---
节点 'generate':

---
('The AlphaCodium paper functions by proposing a code-oriented iterative flow '
 'that involves repeatedly running and fixing generated code against '
 'input-output tests. Its key mechanisms include generating additional data '
 'like problem reflection and test reasoning to aid the iterative process, as '
 'well as enriching the code generation process. AlphaCodium aims to improve '
 'the performance of Large Language Models on code problems by following a '
 'test-based, multi-stage approach.')

注意事项

  • 包版本:请确保安装的包版本与代码兼容,尤其是 langchainlanggraph。如果遇到兼容性问题,请参考相应包的官方文档进行调整。
  • Prompt 模板:代码中使用了 hub.pull("rlm/crag-prompt") 来获取 CRAG 提示语。请确保该提示语存在于 LangChain 的 Hub 中,或者根据需要自定义提示语。
  • 错误处理:为了简化代码示例,未添加详细的错误处理逻辑。在实际应用中,建议添加适当的异常处理,以提高代码的鲁棒性。
  • LangSmith:代码中提到了 LangSmith,用于调试和监控 LangGraph 项目。如果需要使用,请参考 LangSmith 官方文档 进行配置。

评估(Eval)

在本节中,我们将评估使用LangGraph实现的自我纠正RAG系统与基线方法(Context Stuffing)的性能对比。

1. 导入必要模块

import langsmith
from langsmith.schemas import Example, Run
from langsmith.evaluation import evaluate

2. 克隆公共数据集

克隆一个公共的LCEL问题数据集,用于评估:

client = langsmith.Client()

# 克隆数据集到您的租户
try:
    public_dataset = (
        "https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d"
    )
    client.clone_public_dataset(public_dataset)
except:
    print("Please setup LangSmith")

解释:

  • clone_public_dataset:将公共数据集克隆到您的LangSmith租户中,以便进行评估。
  • public_dataset:指定要克隆的数据集URL。

3. 定义自定义评估器

创建两个评估器,用于检查生成的回答是否正确导入和执行。

def check_import(run: Run, example: Example) -> dict:
    """检查导入语句是否正确"""
    imports = run.outputs.get("imports")
    try:
        exec(imports)
        return {"key": "import_check", "score": 1}
    except Exception:
        return {"key": "import_check", "score": 0}

def check_execution(run: Run, example: Example) -> dict:
    """检查代码块是否能正确执行"""
    imports = run.outputs.get("imports")
    code = run.outputs.get("code")
    try:
        exec(imports + "\n" + code)
        return {"key": "code_execution_check", "score": 1}
    except Exception:
        return {"key": "code_execution_check", "score": 0}

解释:

  • check_import:尝试执行导入语句,如果成功,返回分数1;否则,返回分数0。
  • check_execution:尝试执行导入语句和代码块,如果成功,返回分数1;否则,返回分数0。

4. 定义预测函数

定义两个预测函数,分别用于基线方法和自我纠正RAG方法。

def predict_base_case(example: dict):
    """基线方法:Context Stuffing"""
    solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": [("user", example["question"])]}
    )
    return {"imports": solution.imports, "code": solution.code}

def predict_langgraph(example: dict):
    """自我纠正RAG方法"""
    graph = app.invoke(
        {"question": example["question"], "generation": "", "web_search": "No", "documents": []}
    )
    solution = graph["generation"]
    return {"imports": solution.imports, "code": solution.code}

解释:

  • predict_base_case:使用基线方法(Context Stuffing)生成回答。
  • predict_langgraph:使用自我纠正RAG方法生成回答。

5. 运行评估

使用LangSmith的evaluate函数,分别评估基线方法和自我纠正RAG方法的性能。

# 评估器列表
code_evaluator = [check_import, check_execution]

# 数据集名称
dataset_name = "lcel-teacher-eval"

# 运行基线方法的评估
try:
    experiment_results_ = evaluate(
        predict_base_case,
        data=dataset_name,
        evaluators=code_evaluator,
        experiment_prefix=f"test-without-langgraph-{llm.model}",
        max_concurrency=2,
        metadata={
            "llm": llm.model,
        },
    )
except:
    print("Please setup LangSmith")

# 运行自我纠正RAG方法的评估
try:
    experiment_results = evaluate(
        predict_langgraph,
        data=dataset_name,
        evaluators=code_evaluator,
        experiment_prefix=f"test-with-langgraph-{llm.model}-{flag}",
        max_concurrency=2,
        metadata={
            "llm": llm.model,
            "feedback": flag,
        },
    )
except:
    print("Please setup LangSmith")

解释:

  • evaluate:运行评估,传入预测函数、数据集、评估器列表和其他配置参数。
    • predict_base_case:基线方法的预测函数。
    • predict_langgraph:自我纠正RAG方法的预测函数。
    • code_evaluator:评估器列表,用于检查回答的导入和执行情况。
    • experiment_prefix:定义实验的前缀,便于区分不同的实验结果。
    • metadata:附加的元数据,用于记录LLM类型和反馈标志。

6. 结果

根据评估结果,自我纠正RAG方法的表现优于基线方法,特别是在添加重试机制后性能有所提升。然而,反思机制并未带来预期的改进,反而在某些情况下导致性能下降。此外,使用GPT-4模型的性能优于Claude3模型。

结果摘要:

  • 自我纠正RAG优于基线方法(CRAG outperforms base case):添加重试机制显著提高了性能。
  • 反思机制未带来改进(Reflection did not help):在重试前进行反思反而导致性能下降,相比之下,直接将错误反馈给LLM更为有效。
  • GPT-4优于Claude3(GPT-4 outperforms Claude3):GPT-4模型在执行工具调用时的错误率较低,表现优于Claude3模型。

您可以通过访问以下链接查看详细的评估结果:

评估结果链接


总结

通过本节的讲解,您已经学习了如何使用LangGraph实现一个自我纠正的RAG系统。这个系统能够通过自我反思和评分机制,优化检索和生成过程,确保生成的回答既相关又准确。具体来说,您已经掌握了以下内容:

  1. 系统设置:安装必要的包,配置API密钥,并设置LangSmith进行开发和监控。
  2. 索引创建:使用WebBaseLoaderChroma创建检索工具,索引并检索相关文档。
  3. LLM配置
    • 使用OpenAI的GPT-3.5进行路由、评分和生成。
    • 定义Pydantic模型来结构化存储生成的回答和评分结果。
    • 构建路由器、评分器和生成链。
  4. 状态管理:定义图的状态结构,包括用户问题、生成的回答、是否进行网络搜索和相关文档列表。
  5. 图定义
    • 定义检索、生成、评分和重写的节点。
    • 定义条件边路由,决定流程的执行顺序。
  6. 评估
    • 使用LangSmith的评估功能,比较自我纠正RAG方法与基线方法的性能。
    • 通过自定义评估器检查回答的准确性和相关性。

下一步建议:

  • 扩展功能:可以进一步扩展系统,如增加更多的单元测试,集成更多的工具或优化重试和反思机制。
  • 优化路由逻辑:根据评估结果,优化路由器的决策逻辑,提高系统的鲁棒性和生成回答的质量。
  • 多模型集成:结合不同的LLM模型,探索多模型协作的可能性,进一步提升回答的准确性和效率。
  • 部署和监控:将系统部署到生产环境中,并使用LangSmith进行持续的监控和优化,确保系统稳定运行。

通过这个示例,您已经掌握了使用LangGraph构建复杂的AI应用的核心概念和实践技巧,为进一步开发更高级的AI系统奠定了坚实的基础。如果您在实践过程中遇到任何问题,或有任何疑问,欢迎随时提问!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值