AI - RAG中的状态化管理聊天记录

AI - RAG中的状态化管理聊天记录

大家好,今天我们来聊聊LangChain和LLM中一个重要的话题——状态化管理聊天记录。在使用大语言模型(LLM)的时候,聊天记录(History)和状态(State)管理是非常关键的。那我们先从为什么需要状态和历史记录讲起,然后再聊聊如何保留聊天记录。
ai-langchain-rag

为什么需要状态和历史记录

  1. 保持上下文:我们聊天时是有上下文的,比如你刚问我“什么是状态化管理”,下一个问题可能是“怎么实现”,这些问题之间是有联系的。如果没有上下文,LLM每次回答都会像是第一次见到你,回答可能就会前后不一致。
  2. 个性化体验:有了历史记录,我们可以根据用户之前的对话内容来做个性化的回复。这就像是我们和朋友之间的对话,大家了解彼此的喜好和习惯,交流起来更顺畅。
  3. 追踪用户意图:管理聊天状态可以帮助我们更好地理解用户当下的意图。例如,用户可能在连续的问题中逐渐明确他们的需求,通过记录这些对话历史,我们能够更准确地提供帮助。

怎么保留聊天记录呢

  1. 内存保存:最简单的方法就是将历史记录保存在内存中。这种方式适用于短时间的对话,因为内存是有限的,长时间或者大量用户会耗尽内存。实现方法很简单,我们可以用一个列表来存储历史记录。

    chat_history = []  # 用列表来保存聊天记录
    
    def add_message_to_history(user_message, bot_response):
        chat_history.append({
         "user": user_message, "bot": bot_response})
    
  2. 文件或数据库保存:对于需要长时间保存或者需要跨会话保留历史记录的情况,我们可以选择将聊天记录保存到文件或者数据库中。文件保存可以用简单的JSON,数据库可以用SQLite或者更复杂的数据库系统。

    import json
    
    def save_history_to_file(chat_history, filename="chat_history.json"):
        with open(filename, "w") as file:
            json.dump(chat_history, file)
    
    def load_history_from_file(filename="chat_history.json"):
        with open(filename, "r") as file:
            return json.load(file)
    
  3. 状态图管理:在LangChain中,我们可以用状态图(StateGraph)来管理复杂的聊天状态。状态图允许我们定义不同的状态节点,如查询、检索、生成回复等,并设置它们之间的转换条件。这种方式灵活且强大,适用于复杂的对话场景管理。

    from langgraph.graph import StateGraph, MessagesState
    
    state_graph = StateGraph(MessagesState)  # 初始化状态图
    
    # 定义各个状态节点
    state_graph.add_node(query_or_respond)
    state_graph.add_node(tools)
    state_graph.add_node(generate)
    
    # 设置节点之间的条件和转换
    state_graph.set_entry_point("query_or_respond")
    state_graph.add_conditional_edges(
        "query_or_respond",
        tools_condition,
        {
         END: END, "tools": "tools"},
    )
    state_graph.add_edge("tools", "generate")
    state_graph.add_edge("generate", END)
    

如何用代码来实现呢

接下来我们详细讲解如何一步步实现这样的功能。备注:对于本文中的代码片段,主体来源于LangChain官网,有兴趣的读者可以去官网查看。

首先导入了一些必要的库和模块:

import os  # 导入操作系统模块,用来设置环境变量
from langchain_openai import ChatOpenAI  # OpenAI 聊天模型类
from langchain_openai import OpenAIEmbeddings  # OpenAI 嵌入向量类
from langchain_core.vectorstores import InMemoryVectorStore  # 内存向量存储类
import bs4  # BeautifulSoup 库,用于网页解析
from langchain import hub  # langchain 的 hub 模块
from langchain_community.document_loaders import WebBaseLoader  # 加载网页内容的 class
from langchain_core.documents import Document  # 文档类
from langchain_text_splitters import RecursiveCharacterTextSplitter  # 文本切分工具
from langgraph.graph import START, StateGraph, MessagesState  # 状态图相关模块
from typing_extensions import List, TypedDict  # Python 类型扩展
from langchain_core.tools import tool  # 工具装饰器
from langchain_core.messages import SystemMessage  # 系统消息类
from langgraph.graph import END  # 状态图中的结束节点
from langgraph.prebuilt import ToolNode, tools_condition  # 预先构建的工具节点和工具条件
from langgraph.checkpoint.memory import MemorySaver  # 内存保存器

然后,我们设置一个环境变量来存放 OpenAI 的 API Key:

os.environ["OPENAI_API_KEY"] = 'your-api-key'  # 设置 OpenAI API Key

接下来,我们初始化一些主要组件,包括嵌入模型、内存向量存储和语言模型:

embeddings = OpenAIEmbeddings(model="text-embedding-3-large")  # 初始化 OpenAI 嵌入模型
vector_store = InMemoryVectorStore(embeddings)  # 创建一个内存向量存储

llm = ChatOpenAI(model="gpt-4o-mini")  # 初始化聊天模型

下一步是从一个指定的网址加载博客内容,并将其切分成小块:

# 加载并切分博客内容
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),  # 指定要加载的博客的网址
    bs_kwargs=dict(  # BeautifulSoup 的参数,仅解析指定的类
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()  # 加载文档内容

# 设置文本切分器,指定每块大小为1000字符,重叠200字符
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)  # 切分文档

然后,我们将切分后的文档块添加到向量存储中进行索引:

# 索引这些文档块
_ = vector_store.add_documents(documents=all_splits)  # 将文档块添加到向量存储中

接着,我们定义一个名为 retrieve 的工具函数,用于根据用户查询检索信息:

@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    retrieved_docs = vector_store.similarity_search(query, k=2)  # 搜索与查询最相似的2个文档
    serialized = "\n\n".join(
        (f"Source: {
     doc.metadata}\n" f"Content: {
     doc.page_content}")
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs  # 返回序列化的内容和检索到的文档

然后我们定义三个步骤的函数来处理用户的查询和生成回答:

# 第一步:生成包含工具调用的 AI 消息并发送
def query_or_respond(state: MessagesState):
    """生成检索的工具调用或响应。"""
    llm_with_tools = llm.bind_tools([retrieve])  # 绑定 retrieve 工具与聊天模型
    response = llm_with_tools.invoke(state["messages"])  # 生成响应
    # MessagesState 会附加消息到状态而不是覆盖
    return {
   "messages": [response]}  # 返回消息状态

# 第二步:执行检索
tools = ToolNode([retrieve])  # 创建工具节点

# 第三步:使用检索内容生成响应
def generate(state: MessagesState):
    """生成答案。"""
    # 获取生成的工具消息
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]  # 反向排序

    # 格式化为提示
    docs_content = "\n\n".join(doc.content for doc in tool_messages)
   
    system_message_content = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        f"{
     docs_content}"
    )
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    prompt = [SystemMessage(system_message_content)] + conversation_messages

    # 运行
    response = llm.invoke(prompt)
    return {
   "messages": [response]}  # 返回消息状态

然后我们创建一个状态图并设置各个节点和连接:

graph_builder = StateGraph(MessagesState)  # 创建状态图构建器
graph_builder.add_node(query_or_respond)  # 添加 query_or_respond 节点
graph_builder.add_node(tools)  # 添加 tools 节点
graph_builder.add_node(generate)  # 添加 generate 节点

graph_builder.set_entry_point("query_or_respond")  # 设置入口点为 query_or_respond
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {
   END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")  # 从 tools 到 generate 的连接
graph_builder.add_edge("generate", END)  # 从 generate 到 END 的连接

memory = MemorySaver()  # 创建内存保存器
graph = graph_builder.compile(checkpointer=memory)  # 编译状态图,使用内存保存器

最后,我们设置一个线程ID,并模拟两个问题的问答过程:

# 指定线程 ID
config = {
   "configurable": {
   "thread_id": "abc123"}}

# 输入第一个问题
input_message = "What is Task Decomposition?"
for step in graph.stream(
    {
   "messages": [{
   "role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()  # 打印最后一条消息内容

# 输入第二个问题
input_message = "Can you look up some common ways of doing it?"
for step in graph.stream(
    {
   "messages": [{
   "role": "user", "content": input_message}]},
    stream_mode="values",
    config=config,
):
    step["messages"][-1].pretty_print()  # 打印最后一条消息内容

总的来说,这段代码是一个完整的流程,用于从网页加载文档、切分文档、存储文档向量,然后使用这些数据来回应用户的查询。下面是代码输出:

================================ Human Message =================================

What is Task Decomposition?
================================== Ai Message ==================================
Tool Calls:
  retrieve (call_pUlHd3ysUAh2666YBKXL75XX)
 Call ID: call_pUlHd3ysUAh2666YBKXL75XX
  Args:
    query: Task Decomposition
================================= Tool Message ============&
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值