AI - RAG中的状态化管理聊天记录
大家好,今天我们来聊聊LangChain和LLM中一个重要的话题——状态化管理聊天记录。在使用大语言模型(LLM)的时候,聊天记录(History)和状态(State)管理是非常关键的。那我们先从为什么需要状态和历史记录讲起,然后再聊聊如何保留聊天记录。
为什么需要状态和历史记录
- 保持上下文:我们聊天时是有上下文的,比如你刚问我“什么是状态化管理”,下一个问题可能是“怎么实现”,这些问题之间是有联系的。如果没有上下文,LLM每次回答都会像是第一次见到你,回答可能就会前后不一致。
- 个性化体验:有了历史记录,我们可以根据用户之前的对话内容来做个性化的回复。这就像是我们和朋友之间的对话,大家了解彼此的喜好和习惯,交流起来更顺畅。
- 追踪用户意图:管理聊天状态可以帮助我们更好地理解用户当下的意图。例如,用户可能在连续的问题中逐渐明确他们的需求,通过记录这些对话历史,我们能够更准确地提供帮助。
怎么保留聊天记录呢
-
内存保存:最简单的方法就是将历史记录保存在内存中。这种方式适用于短时间的对话,因为内存是有限的,长时间或者大量用户会耗尽内存。实现方法很简单,我们可以用一个列表来存储历史记录。
chat_history = [] # 用列表来保存聊天记录 def add_message_to_history(user_message, bot_response): chat_history.append({ "user": user_message, "bot": bot_response})
-
文件或数据库保存:对于需要长时间保存或者需要跨会话保留历史记录的情况,我们可以选择将聊天记录保存到文件或者数据库中。文件保存可以用简单的JSON,数据库可以用SQLite或者更复杂的数据库系统。
import json def save_history_to_file(chat_history, filename="chat_history.json"): with open(filename, "w") as file: json.dump(chat_history, file) def load_history_from_file(filename="chat_history.json"): with open(filename, "r") as file: return json.load(file)
-
状态图管理:在LangChain中,我们可以用状态图(StateGraph)来管理复杂的聊天状态。状态图允许我们定义不同的状态节点,如查询、检索、生成回复等,并设置它们之间的转换条件。这种方式灵活且强大,适用于复杂的对话场景管理。
from langgraph.graph import StateGraph, MessagesState state_graph = StateGraph(MessagesState) # 初始化状态图 # 定义各个状态节点 state_graph.add_node(query_or_respond) state_graph.add_node(tools) state_graph.add_node(generate) # 设置节点之间的条件和转换 state_graph.set_entry_point("query_or_respond") state_graph.add_conditional_edges( "query_or_respond", tools_condition, { END: END, "tools": "tools"}, ) state_graph.add_edge("tools", "generate") state_graph.add_edge("generate", END)
如何用代码来实现呢
接下来我们详细讲解如何一步步实现这样的功能。备注:对于本文中的代码片段,主体来源于LangChain官网,有兴趣的读者可以去官网查看。
首先导入了一些必要的库和模块:
import os # 导入操作系统模块,用来设置环境变量
from langchain_openai import ChatOpenAI # OpenAI 聊天模型类
from langchain_openai import OpenAIEmbeddings # OpenAI 嵌入向量类
from langchain_core.vectorstores import InMemoryVectorStore # 内存向量存储类
import bs4 # BeautifulSoup 库,用于网页解析
from langchain import hub # langchain 的 hub 模块
from langchain_community.document_loaders import WebBaseLoader # 加载网页内容的 class
from langchain_core.documents import Document # 文档类
from langchain_text_splitters import RecursiveCharacterTextSplitter # 文本切分工具
from langgraph.graph import START, StateGraph, MessagesState # 状态图相关模块
from typing_extensions import List, TypedDict # Python 类型扩展
from langchain_core.tools import tool # 工具装饰器
from langchain_core.messages import SystemMessage # 系统消息类
from langgraph.graph import END # 状态图中的结束节点
from langgraph.prebuilt import ToolNode, tools_condition # 预先构建的工具节点和工具条件
from langgraph.checkpoint.memory import MemorySaver # 内存保存器
然后,我们设置一个环境变量来存放 OpenAI 的 API Key:
os.environ["OPENAI_API_KEY"] = 'your-api-key' # 设置 OpenAI API Key
接下来,我们初始化一些主要组件,包括嵌入模型、内存向量存储和语言模型:
embeddings = OpenAIEmbeddings(model="text-embedding-3-large") # 初始化 OpenAI 嵌入模型
vector_store = InMemoryVectorStore(embeddings) # 创建一个内存向量存储
llm = ChatOpenAI(model="gpt-4o-mini") # 初始化聊天模型
下一步是从一个指定的网址加载博客内容,并将其切分成小块:
# 加载并切分博客内容
loader = WebBaseLoader(
web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",), # 指定要加载的博客的网址
bs_kwargs=dict( # BeautifulSoup 的参数,仅解析指定的类
parse_only=bs4.SoupStrainer(
class_=("post-content", "post-title", "post-header")
)
),
)
docs = loader.load() # 加载文档内容
# 设置文本切分器,指定每块大小为1000字符,重叠200字符
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs) # 切分文档
然后,我们将切分后的文档块添加到向量存储中进行索引:
# 索引这些文档块
_ = vector_store.add_documents(documents=all_splits) # 将文档块添加到向量存储中
接着,我们定义一个名为 retrieve
的工具函数,用于根据用户查询检索信息:
@tool(response_format="content_and_artifact")
def retrieve(query: str):
"""Retrieve information related to a query."""
retrieved_docs = vector_store.similarity_search(query, k=2) # 搜索与查询最相似的2个文档
serialized = "\n\n".join(
(f"Source: {
doc.metadata}\n" f"Content: {
doc.page_content}")
for doc in retrieved_docs
)
return serialized, retrieved_docs # 返回序列化的内容和检索到的文档
然后我们定义三个步骤的函数来处理用户的查询和生成回答:
# 第一步:生成包含工具调用的 AI 消息并发送
def query_or_respond(state: MessagesState):
"""生成检索的工具调用或响应。"""
llm_with_tools = llm.bind_tools([retrieve]) # 绑定 retrieve 工具与聊天模型
response = llm_with_tools.invoke(state["messages"]) # 生成响应
# MessagesState 会附加消息到状态而不是覆盖
return {
"messages": [response]} # 返回消息状态
# 第二步:执行检索
tools = ToolNode([retrieve]) # 创建工具节点
# 第三步:使用检索内容生成响应
def generate(state: MessagesState):
"""生成答案。"""
# 获取生成的工具消息
recent_tool_messages = []
for message in reversed(state["messages"]):
if message.type == "tool":
recent_tool_messages.append(message)
else:
break
tool_messages = recent_tool_messages[::-1] # 反向排序
# 格式化为提示
docs_content = "\n\n".join(doc.content for doc in tool_messages)
system_message_content = (
"You are an assistant for question-answering tasks. "
"Use the following pieces of retrieved context to answer "
"the question. If you don't know the answer, say that you "
"don't know. Use three sentences maximum and keep the "
"answer concise."
"\n\n"
f"{
docs_content}"
)
conversation_messages = [
message
for message in state["messages"]
if message.type in ("human", "system")
or (message.type == "ai" and not message.tool_calls)
]
prompt = [SystemMessage(system_message_content)] + conversation_messages
# 运行
response = llm.invoke(prompt)
return {
"messages": [response]} # 返回消息状态
然后我们创建一个状态图并设置各个节点和连接:
graph_builder = StateGraph(MessagesState) # 创建状态图构建器
graph_builder.add_node(query_or_respond) # 添加 query_or_respond 节点
graph_builder.add_node(tools) # 添加 tools 节点
graph_builder.add_node(generate) # 添加 generate 节点
graph_builder.set_entry_point("query_or_respond") # 设置入口点为 query_or_respond
graph_builder.add_conditional_edges(
"query_or_respond",
tools_condition,
{
END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate") # 从 tools 到 generate 的连接
graph_builder.add_edge("generate", END) # 从 generate 到 END 的连接
memory = MemorySaver() # 创建内存保存器
graph = graph_builder.compile(checkpointer=memory) # 编译状态图,使用内存保存器
最后,我们设置一个线程ID,并模拟两个问题的问答过程:
# 指定线程 ID
config = {
"configurable": {
"thread_id": "abc123"}}
# 输入第一个问题
input_message = "What is Task Decomposition?"
for step in graph.stream(
{
"messages": [{
"role": "user", "content": input_message}]},
stream_mode="values",
config=config,
):
step["messages"][-1].pretty_print() # 打印最后一条消息内容
# 输入第二个问题
input_message = "Can you look up some common ways of doing it?"
for step in graph.stream(
{
"messages": [{
"role": "user", "content": input_message}]},
stream_mode="values",
config=config,
):
step["messages"][-1].pretty_print() # 打印最后一条消息内容
总的来说,这段代码是一个完整的流程,用于从网页加载文档、切分文档、存储文档向量,然后使用这些数据来回应用户的查询。下面是代码输出:
================================ Human Message =================================
What is Task Decomposition?
================================== Ai Message ==================================
Tool Calls:
retrieve (call_pUlHd3ysUAh2666YBKXL75XX)
Call ID: call_pUlHd3ysUAh2666YBKXL75XX
Args:
query: Task Decomposition
================================= Tool Message ============&