[LangGraph Tutorial] LangGraph 03: Adding Memory to the Chatbot

In the previous post, [LangGraph Tutorial] LangGraph 02: Augmenting the Chatbot with Tools, we gave the chatbot the ability to call external tools. However, the bot still cannot remember the context of earlier turns, which limits its ability to hold a coherent multi-turn conversation.

LangGraph solves this with its persistent checkpointing mechanism. Specifically, if you provide a checkpointer when compiling the graph and a thread_id when invoking it, LangGraph automatically saves the current state after every step. When the graph is invoked again with the same thread_id, it automatically loads the previously saved state, so the chatbot can pick up the conversation right where it left off.

This checkpointing mechanism enables more than multi-turn conversation. It also powers error recovery, human-in-the-loop workflows, time-travel interactions, and similar scenarios.

The tool-using chatbot

The code is below; see [LangGraph Tutorial] LangGraph 02: Augmenting the Chatbot with Tools for a detailed walkthrough:

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph

# 1. Define the state class
class State(TypedDict):
    messages: Annotated[list, add_messages]

# 2. Build the graph
graph_builder = StateGraph(State)

# 3. Define the tool node
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI

local_llm = ["deepseek-r1:8b","qwen2.5:latest"]

llm = ChatOpenAI(model=local_llm[1], temperature=0.0, api_key="ollama", base_url="http://localhost:11434/v1")

tool = TavilySearchResults(max_results=2)
tools = [tool]
llm_with_tools = llm.bind_tools(tools)

def chatbot(state: State):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

# 4. Add nodes to the graph
# The first argument is the node name, the second is the function to call
from langgraph.prebuilt import ToolNode

graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools=[tool])
graph_builder.add_node("tools", tool_node)

# 5. Add edges to the graph
from langgraph.prebuilt import tools_condition

graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")

graph = graph_builder.compile()

graph.invoke({"messages": [{"role": "user", "content": "What's the weather in Wuhan today?"}]})

from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))

Enabling multi-turn conversations

The first steps are the same as before, so we won't repeat the explanation:

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph

# 1. Define the state class
class State(TypedDict):
    messages: Annotated[list, add_messages]

# 2. Build the graph
graph_builder = StateGraph(State)

# 3. Define the tool node
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI

local_llm = ["deepseek-r1:8b","qwen2.5:latest"]

llm = ChatOpenAI(model=local_llm[1], temperature=0.0, api_key="ollama", base_url="http://localhost:11434/v1")

tool = TavilySearchResults(max_results=2)
tools = [tool]
llm_with_tools = llm.bind_tools(tools)

def chatbot(state: State):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

# 4. Add nodes to the graph
# The first argument is the node name, the second is the function to call
from langgraph.prebuilt import ToolNode

graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools=[tool])
graph_builder.add_node("tools", tool_node)

# 5. Add edges to the graph
from langgraph.prebuilt import tools_condition

graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")

Provide a checkpointer when compiling the graph. A checkpointer is an object that saves and loads state; it can be a simple in-memory or file-based implementation, or a more sophisticated database-backed one. Here we keep the checkpoints in memory using MemorySaver:

from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()

graph = graph_builder.compile(checkpointer=memory)

Provide a thread_id when invoking the graph. The thread_id is a unique identifier that distinguishes conversation threads. By reusing the same thread_id, LangGraph can load the previously saved state, which is what enables multi-turn conversation.

Here we set the thread id to 1:

user_input = "Hi there! My name is Will."

config = {"configurable": {"thread_id": "1"}}

events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    config,
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

The output:

================================Human Message=================================

Hi there! My name is Will.
==================================Ai Message==================================

Hello Will! How can I assist you today?

After each step, LangGraph automatically calls the checkpointer to save the current state. When the graph is invoked again with the same thread_id, LangGraph automatically loads that saved state.

Now let's ask about the name mentioned earlier and see whether the model remembers:

user_input = "Remember my name?"


events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    config,
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

The response:

================================Human Message=================================

Remember my name?
==================================Ai Message==================================

Of course! You mentioned your name is Will. How can I help you today, Will?

As the response shows, the model answered the question correctly.

Now let's switch to a different thread id and see whether the model still remembers:

# The only difference is we change the `thread_id` here to "2" instead of "1"
events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    {"configurable": {"thread_id": "2"}},
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

As the output below shows, the model's answer is wrong, which confirms that different threads do not share context:

================================Human Message=================================

Remember my name?
==================================Ai Message==================================

Of course! You can always refer to me as Qwen. How can I assist you today?

Finally, you can inspect the current state with the get_state function:

snapshot = graph.get_state(config)
snapshot

The result:

StateSnapshot(
    values={
        "messages": [
            HumanMessage(
                content="Hi there! My name is Will.",
                additional_kwargs={},
                response_metadata={},
                id="bb8fc848-d778-4edb-86ed-af6c1b1b5cde",
            ),
            HumanMessage(
                content="Hi there! My name is Will.",
                additional_kwargs={},
                response_metadata={},
                id="a38b9310-bab0-4838-b9ca-ef5656015012",
            ),
            AIMessage(
                content="Hello Will! How can I assist you today?",
                additional_kwargs={"refusal": None},
                response_metadata={
                    "token_usage": {
                        "completion_tokens": 11,
                        "prompt_tokens": 203,
                        "total_tokens": 214,
                        "completion_tokens_details": None,
                        "prompt_tokens_details": None,
                    },
                    "model_name": "qwen2.5:latest",
                    "system_fingerprint": "fp_ollama",
                    "id": "chatcmpl-423",
                    "finish_reason": "stop",
                    "logprobs": None,
                },
                id="run-67c235d5-76ad-4287-a1bd-bb2a4729206f-0",
                usage_metadata={
                    "input_tokens": 203,
                    "output_tokens": 11,
                    "total_tokens": 214,
                    "input_token_details": {},
                    "output_token_details": {},
                },
            ),
            HumanMessage(
                content="Remember my name?",
                additional_kwargs={},
                response_metadata={},
                id="36e412ef-daa0-422e-b93c-25b9236985bc",
            ),
            AIMessage(
                content="Of course! You mentioned your name is Will. How can I help you today, Will?",
                additional_kwargs={"refusal": None},
                response_metadata={
                    "token_usage": {
                        "completion_tokens": 20,
                        "prompt_tokens": 227,
                        "total_tokens": 247,
                        "completion_tokens_details": None,
                        "prompt_tokens_details": None,
                    },
                    "model_name": "qwen2.5:latest",
                    "system_fingerprint": "fp_ollama",
                    "id": "chatcmpl-217",
                    "finish_reason": "stop",
                    "logprobs": None,
                },
                id="run-649f5f0c-c2d2-4cb0-951d-8c62a56e717a-0",
                usage_metadata={
                    "input_tokens": 227,
                    "output_tokens": 20,
                    "total_tokens": 247,
                    "input_token_details": {},
                    "output_token_details": {},
                },
            ),
        ]
    },
    next=(),
    config={
        "configurable": {
            "thread_id": "1",
            "checkpoint_ns": "",
            "checkpoint_id": "1f020170-d033-62b8-8006-dbfcb29ec140",
        }
    },
    metadata={
        "source": "loop",
        "writes": {
            "chatbot": {
                "messages": [
                    AIMessage(
                        content="Of course! You mentioned your name is Will. How can I help you today, Will?",
                        additional_kwargs={"refusal": None},
                        response_metadata={
                            "token_usage": {
                                "completion_tokens": 20,
                                "prompt_tokens": 227,
                                "total_tokens": 247,
                                "completion_tokens_details": None,
                                "prompt_tokens_details": None,
                            },
                            "model_name": "qwen2.5:latest",
                            "system_fingerprint": "fp_ollama",
                            "id": "chatcmpl-217",
                            "finish_reason": "stop",
                            "logprobs": None,
                        },
                        id="run-649f5f0c-c2d2-4cb0-951d-8c62a56e717a-0",
                        usage_metadata={
                            "input_tokens": 227,
                            "output_tokens": 20,
                            "total_tokens": 247,
                            "input_token_details": {},
                            "output_token_details": {},
                        },
                    )
                ]
            }
        },
        "step": 6,
        "parents": {},
        "thread_id": "1",
    },
    created_at="2025-04-23T07:46:22.313541+00:00",
    parent_config={
        "configurable": {
            "thread_id": "1",
            "checkpoint_ns": "",
            "checkpoint_id": "1f020170-b584-64d5-8005-f78ed589b2f8",
        }
    },
    tasks=(),
)

Full code

from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph

# 1. Define the state class
class State(TypedDict):
    messages: Annotated[list, add_messages]

# 2. Build the graph
graph_builder = StateGraph(State)

# 3. Define the tool node
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI

local_llm = ["deepseek-r1:8b","qwen2.5:latest"]

llm = ChatOpenAI(model=local_llm[1], temperature=0.0, api_key="ollama", base_url="http://localhost:11434/v1")

tool = TavilySearchResults(max_results=2)
tools = [tool]
llm_with_tools = llm.bind_tools(tools)

def chatbot(state: State):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

# 4. Add nodes to the graph
# The first argument is the node name, the second is the function to call
from langgraph.prebuilt import ToolNode

graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools=[tool])
graph_builder.add_node("tools", tool_node)

# 5. Add edges to the graph
from langgraph.prebuilt import tools_condition

graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")

from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()

graph = graph_builder.compile(checkpointer=memory)

user_input = "Hi there! My name is Will."

config = {"configurable": {"thread_id": "1"}}

events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    config,
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

user_input = "Remember my name?"

events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    config,
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

# The only difference is we change the `thread_id` here to "2" instead of "1"
events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    {"configurable": {"thread_id": "2"}},
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

snapshot = graph.get_state(config)
print(snapshot)
