前言
在这篇文章中,将为大家介绍可能是目前最强大的AI Agent设计框架,集多种规划和反思技术的集大成者,LATS。文章内容会相对比较复杂难懂,值得收藏和反复研读。
LATS的概念
LATS,全称是Language Agent Tree Search,说的更直白一些,LATS = Tree search + ReAct + Plan&Execute + Reflexion。这么来看,LATS确实非常高级和复杂,下面我们根据上面的等式,先从宏观上拆解一下LATS。
1. Tree Search
Tree Search是一种树搜索算法,LATS 使用蒙特卡罗树搜索(MCTS)算法,通过平衡探索和利用,找到最优决策路径。
蒙特卡罗方法可能大家都比较熟悉了,是一种通过随机采样模拟来求解问题的方法。通过生成随机数,建立概率模型,以解决难以通过其他方法解决的数值问题。蒙特卡罗方法的一个典型应用是求定积分。假设我们要计算函数 f(x) 在[a, b]之间的积分,即阴影部分面积。
蒙特卡罗方法的解法如下:在[a, b]之间取一个随机数 x,用 f(x)⋅(b−a) 来估计阴影部分的面积。为了提高估计精度,可以取多个随机数 x,然后取这些估计值的平均值作为最终结果。当取的随机数 x 越多,结果将越准确,估计值将越接近真实值。
蒙特卡罗树搜索(MCTS)则是一种基于树结构的蒙特卡罗方法。它在整个 2^N(N 为决策次数,即树深度)空间中进行启发式搜索,通过反馈机制寻找最优路径。MCTS 的五个主要核心部分是:
-
树结构:每一个叶子节点到根节点的路径都对应一个解,解空间大小为 2^N。
-
蒙特卡罗方法:通过随机统计方法获取观测结果,驱动搜索过程。
-
损失评估函数:设计一个可量化的损失函数,提供反馈评估解的优劣。
-
反向传播线性优化:采用反向传播对路径上的所有节点进行优化。
-
启发式搜索策略:遵循损失最小化原则,在整个搜索空间上进行启发式搜索。
MCTS 的每个循环包括四个步骤:
-
选择(Selection):从根节点开始,按照最大化某种启发式价值选择子节点,直到到达叶子节点。使用上置信区间算法(UCB)选择子节点。
-
扩展(Expansion):如果叶子节点不是终止节点,扩展该节点,添加一个或多个子节点。
-
仿真(Simulation):从新扩展的节点开始,进行随机模拟,直到到达终止状态。
-
反向传播(Backpropagation):将模拟结果沿着路径反向传播,更新每个节点的统计信息。
2. ReAct
ReAct的概念和设计模式,在此前已做过详细介绍。
它的典型流程如下图所示,可以用一个有趣的循环来描述:思考(Thought)→ 行动(Action)→ 观察(Observation),简称TAO循环。
-
思考(Thought):面对一个问题,我们需要进行深入的思考。这个思考过程是关于如何定义问题、确定解决问题所需的关键信息和推理步骤。
-
行动(Action):确定了思考的方向后,接下来就是行动的时刻。根据我们的思考,采取相应的措施或执行特定的任务,以期望推动问题向解决的方向发展。
-
观察(Observation):行动之后,我们必须仔细观察结果。这一步是检验我们的行动是否有效,是否接近了问题的答案。
-
循环迭代
3. Plan & Execute
Plan-and-Execute这个方法的本质是先计划再执行,即先把用户的问题分解成一个个的子任务,然后再执行各个子任务,并根据执行情况调整计划。
4. Reflexion
Reflexion的本质是Basic Reflection加上强化学习,完整的Reflexion框架由三个部分组成:
-
参与者(Actor):根据状态观测量生成文本和动作。参与者在环境中采取行动并接受观察结果,从而形成轨迹。前文所介绍的Reflexion Agent,其实指的就是这一块
-
评估者(Evaluator):对参与者的输出进行评价。具体来说,它将生成的轨迹(也被称作短期记忆)作为输入并输出奖励分数。根据人物的不同,使用不同的奖励函数(决策任务使用LLM和基于规则的启发式奖励)。
-
自我反思(Self-Reflection):这个角色由大语言模型承担,能够为未来的试验提供宝贵的反馈。自我反思模型利用奖励信号、当前轨迹和其持久记忆生成具体且相关的反馈,并存储在记忆组件中。智能体利用这些经验(存储在长期记忆中)来快速改进决策。
因此,融合了Tree Search、ReAct、Plan & Execute、Reflexion的能力于一身之后,LATS成为AI Agent设计模式中,集反思模式和规划模式的大成者。
LATS的工作流程
LATS的工作流程如下图所示,包括以下步骤:
-
选择 (Selection):即从根节点开始,使用上置信区树 (UCT) 算法选择具有最高 UCT 值的子节点进行扩展。
-
扩展 (Expansion):通过从预训练语言模型 (LM) 中采样 n 个动作扩展树,接收每个动作并返回反馈,然后增加 n 个新的子节点。
-
评估 (Evaluation):为每个新子节点分配一个标量值,以指导搜索算法前进,LATS 通过 LM 生成的评分和自一致性得分设计新的价值函数。
-
模拟 (Simulation):扩展当前选择的节点直到达到终端状态,优先选择最高价值的节点。
-
回溯 (Backpropagation):根据轨迹结果更新树的值,路径中的每个节点的值被更新以反映模拟结果。
-
反思 (Reflection):在遇到不成功的终端节点时,LM 生成自我反思,总结过程中的错误并提出改进方案。这些反思和失败轨迹在后续迭代中作为额外上下文整合,帮助提高模型的表现。
下图是在langchain中实现LATS的过程:
第一步,选择:根据下面步骤中的总奖励选择最佳的下一步行动,如果找到解决方案或达到最大搜索深度,做出响应;否则就继续搜索。
第二步,扩展和执行:生成N个潜在操作,并且并行执行。
第三步,反思和评估:观察行动的结果,并根据反思和外部反馈对决策评分。
第四步,反向传播:根据结果更新轨迹的分数。
LATS的实现过程
下面,风叔通过实际的源码,详细介绍LATS模式的实现方法,具体的源代码地址可以在文末获取。
第一步 构建树节点
LATS 基于蒙特卡罗树搜索。对于每个搜索步骤,它都会选择具有最高“置信上限”的节点,这是一个平衡开发(最高平均奖励)和探索(最低访问量)的指标。从该节点开始,它会生成 N(在本例中为 5)个新的候选操作,并将它们添加到树中。当它生成有效解决方案或达到最大次数(搜索树深度)时,会停止搜索。
在Node节点中,我们定义了几个关键的函数:
-
best_child:选择 UCT 最高的子项进行下一步搜索
-
best_child_score:返回具有最高价值的子项
-
height:检查已经推进的树的深度
-
upper_confidence_bound:返回 UCT 分数,平衡分支的探索与利用
-
backpropogate:利用反向传播,更新此节点及其父节点的分数
-
get_trajectory:获取代表此搜索分支的消息
-
get_best_solution:返回当前子树中的最佳解决方案
import math``from collections import deque``from typing import Optional``from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage`` ``class Node:` `def __init__(` `self,` `messages: list[BaseMessage],` `reflection: Reflection,` `parent: Optional[Node] = None,` `):` `self.messages = messages` `self.parent = parent` `self.children = []` `self.value = 0` `self.visits = 0` `self.reflection = reflection` `self.depth = parent.depth + 1 if parent is not None else 1` `self._is_solved = reflection.found_solution if reflection else False` `if self._is_solved:` `self._mark_tree_as_solved()` `self.backpropagate(reflection.normalized_score)`` ` `def __repr__(self) -> str:` `return (` `f"<Node value={self.value}, visits={self.visits},"` `f" solution={self.messages} reflection={self.reflection}/>"` `)`` ` `@property` `def is_solved(self):` `"""If any solutions exist, we can end the search."""` `return self._is_solved`` ` `@property` `def is_terminal(self):` `return not self.children`` ` `@property` `def best_child(self):` `"""Select the child with the highest UCT to search next."""` `if not self.children:` `return None` `all_nodes = self._get_all_children()` `return max(all_nodes, key=lambda child: child.upper_confidence_bound())`` ` `@property` `def best_child_score(self):` `"""Return the child with the highest value."""` `if not self.children:` `return None` `return max(self.children, key=lambda child: int(child.is_solved) * child.value)`` ` `@property` `def height(self) -> int:` `"""Check for how far we've rolled out the tree."""` `if self.children:` `return 1 + max([child.height for child in self.children])` `return 1`` ` `def upper_confidence_bound(self, exploration_weight=1.0):` `"""Return the UCT score. This helps balance exploration vs. exploitation of a branch."""` `if self.parent is None:` `raise ValueError("Cannot obtain UCT from root node")` `if self.visits == 0:` `return self.value` `# Encourages exploitation of high-value trajectories` `average_reward = self.value / self.visits` `# Encourages exploration of less-visited trajectories` `exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)` `return average_reward + exploration_weight * exploration_term`` ` `def backpropagate(self, reward: float):` `"""Update the score of this node and its parents."""` `node = self` `while node:` `node.visits += 1` `node.value = (node.value * (node.visits - 1) + reward) / node.visits` `node = node.parent`` ` `def get_messages(self, include_reflections: bool = True):` `if include_reflections:` `return self.messages + [self.reflection.as_message()]` `return self.messages`` ` `def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:` `"""Get messages representing this search branch."""` `messages = []` `node = self` `while node:` `messages.extend(` `node.get_messages(include_reflections=include_reflections)[::-1]` `)` `node = node.parent` `# Reverse the final back-tracked trajectory to return in the correct order` `return messages[::-1] # root solution, reflection, child 1, ...`` ` `def _get_all_children(self):` `all_nodes = []` `nodes = deque()` `nodes.append(self)` `while nodes:` `node = nodes.popleft()` `all_nodes.extend(node.children)` `for n in node.children:` `nodes.append(n)` `return all_nodes`` ` `def get_best_solution(self):` `"""Return the best solution from within the current sub-tree."""` `all_nodes = [self] + self._get_all_children()` `best_node = max(` `all_nodes,` `# We filter out all non-terminal, non-solution trajectories` `key=lambda node: int(node.is_terminal and node.is_solved) * node.value,` `)` `return best_node`` ` `def _mark_tree_as_solved(self):` `parent = self.parent` `while parent:` `parent._is_solved = True` `parent = parent.parent
第二步 构建Agent
Agent将主要处理三个事项:
-
反思:根据工具执行响应的结果打分
-
初始响应:创建根节点,并开始搜索
-
扩展:从当前树中的最佳位置,生成5个候选的下一步
对于更多实际的应用,比如代码生成,可以将代码执行结果集成到反馈或奖励中,这种外部反馈对Agent效果的提升将非常有用。
对于Agent,首先构建工具Tools,我们只使用了一个搜索引擎工具。
from langchain_community.tools.tavily_search import TavilySearchResults``from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper``from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation`` ``search = TavilySearchAPIWrapper()``tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)``tools = [tavily_tool]``tool_executor = ToolExecutor(tools=tools)
然后,构建反射系统,反射系统将根据决策和工具使用结果,对Agent的输出进行打分,我们将在其他两个节点中调用此方法。
class Reflection(BaseModel):` `reflections: str = Field(` `description="The critique and reflections on the sufficiency, superfluency,"` `" and general quality of the response"` `)` `score: int = Field(` `description="Score from 0-10 on the quality of the candidate response.",` `gte=0,` `lte=10,` `)` `found_solution: bool = Field(` `description="Whether the response has fully solved the question or task."` `)`` ` `def as_message(self):` `return HumanMessage(` `content=f"Reasoning: {self.reflections}\nScore: {self.score}"` `)`` ` `@property` `def normalized_score(self) -> float:` `return self.score / 10.0`` ``prompt = ChatPromptTemplate.from_messages(` `[` `(` `"system",` `"Reflect and grade the assistant response to the user question below.",` `),` `("user", "{input}"),` `MessagesPlaceholder(variable_name="candidate"),` `]``)`` ``reflection_llm_chain = (` `prompt` `| llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(` `run_name="Reflection"` `)` `| PydanticToolsParser(tools=[Reflection])``)`` ``@as_runnable``def reflection_chain(inputs) -> Reflection:` `tool_choices = reflection_llm_chain.invoke(inputs)` `reflection = tool_choices[0]` `if not isinstance(inputs["candidate"][-1], AIMessage):` `reflection.found_solution = False` `return reflection
接下来,我们从根节点开始,根据用户输入进行响应
from langchain_core.prompt_values import ChatPromptValue``from langchain_core.runnables import RunnableConfig`` ``prompt_template = ChatPromptTemplate.from_messages(` `[` `(` `"system",` `"You are an AI assistant.",` `),` `("user", "{input}"),` `MessagesPlaceholder(variable_name="messages", optional=True),` `]``)`` ``initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(` `run_name="GenerateInitialCandidate"``)`` ``parser = JsonOutputToolsParser(return_id=True)``initial_response = initial_answer_chain.invoke(` `{"input": "Write a research report on lithium pollution."}``)
然后开始根节点,我们将候选节点生成和reflection打包到单个节点中。
import json`` ``# Define the node we will add to the graph``def generate_initial_response(state: TreeState) -> dict:` `"""Generate the initial candidate response."""` `res = initial_answer_chain.invoke({"input": state["input"]})` `parsed = parser.invoke(res)` `tool_responses = tool_executor.batch(` `[ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in parsed]` `)` `output_messages = [res] + [` `ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])` `for resp, tool_call in zip(tool_responses, parsed)` `]` `reflection = reflection_chain.invoke(` `{"input": state["input"], "candidate": output_messages}` `)` `root = Node(output_messages, reflection=reflection)` `return {` `**state,` `"root": root,` `}
第三步 生成候选节点
对于每个节点,生成5个待探索的候选节点。
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):` `n = config["configurable"].get("N", 5)` `bound_kwargs = llm.bind_tools(tools=tools).kwargs` `chat_result = llm.generate(` `[messages.to_messages()],` `n=n,` `callbacks=config["callbacks"],` `run_name="GenerateCandidates",` `**bound_kwargs,` `)` `return [gen.message for gen in chat_result.generations[0]]`` ``expansion_chain = prompt_template | generate_candidates``res = expansion_chain.invoke({"input": "Write a research report on lithium pollution."})
将候选节点生成和refleciton步骤打包在下面的扩展节点中,所有操作都以批处理的方式进行,以加快执行速度。
from collections import defaultdict`` ``def expand(state: TreeState, config: RunnableConfig) -> dict:` `"""Starting from the "best" node in the tree, generate N candidates for the next step."""` `root = state["root"]` `best_candidate: Node = root.best_child if root.children else root` `messages = best_candidate.get_trajectory()` `# Generate N candidates from the single child candidate` `new_candidates = expansion_chain.invoke(` `{"input": state["input"], "messages": messages}, config` `)` `parsed = parser.batch(new_candidates)` `flattened = [` `(i, tool_call)` `for i, tool_calls in enumerate(parsed)` `for tool_call in tool_calls` `]` `tool_responses = tool_executor.batch(` `[` `ToolInvocation(tool=tool_call["type"], tool_input=tool_call["args"])` `for _, tool_call in flattened` `]` `)` `collected_responses = defaultdict(list)` `for (i, tool_call), resp in zip(flattened, tool_responses):` `collected_responses[i].append(` `ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])` `)` `output_messages = []` `for i, candidate in enumerate(new_candidates):` `output_messages.append([candidate] + collected_responses[i])`` ` `# Reflect on each candidate` `# For tasks with external validation, you'd add that here.` `reflections = reflection_chain.batch(` `[{"input": state["input"], "candidate": msges} for msges in output_messages],` `config,` `)` `# Grow tree` `child_nodes = [` `Node(cand, parent=best_candidate, reflection=reflection)` `for cand, reflection in zip(output_messages, reflections)` `]` `best_candidate.children.extend(child_nodes)` `# We have already extended the tree directly, so we just return the state` `return state
第四步 构建流程图
下面,我们构建流程图,将根节点和扩展节点加入进来
from typing import Literal`` ``from langgraph.graph import END, StateGraph, START`` ``def should_loop(state: TreeState) -> Literal["expand", "__end__"]:` `"""Determine whether to continue the tree search."""` `root = state["root"]` `if root.is_solved:` `return END` `if root.height > 5:` `return END` `return "expand"`` `` ``builder = StateGraph(TreeState)``builder.add_node("start", generate_initial_response)``builder.add_node("expand", expand)``builder.add_edge(START, "start")`` ``builder.add_conditional_edges(` `"start",` `# Either expand/rollout or finish` `should_loop,``)``builder.add_conditional_edges(` `"expand",` `# Either continue to rollout or finish` `should_loop,``)`` ``graph = builder.compile()
至此,整个LATS的核心逻辑就介绍完了。大家可以关注公众号【风叔云】,回复关键词【LATS源码】,获取LATS设计模式的完整源代码。
总结
与其他基于树的方法相比,LATS实现了自我反思的推理步骤,显著提升了性能。当采取行动后,LATS不仅利用环境反馈,还结合来自语言模型的反馈,以判断推理中是否存在错误并提出替代方案。这种自我反思的能力与其强大的搜索算法相结合,使得LATS更适合处理一些相对复杂的任务。
然而,由于算法本身的复杂性以及涉及的反思步骤,LATS通常比其他单智能体方法使用更多的计算资源,并且完成任务所需的时间更长。
如何学习大模型 AI ?
由于新岗位的生产效率,要优于被取代岗位的生产效率,所以实际上整个社会的生产效率是提升的。
但是具体到个人,只能说是:
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,也可以通过我们的能力和经验解答大家在人工智能学习中的很多困惑,所以在工作繁忙的情况下还是坚持各种整理和分享。但苦于知识传播途径有限,很多互联网行业朋友无法获得正确的资料得到学习提升,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
第一阶段(10天):初阶应用
该阶段让大家对大模型 AI有一个最前沿的认识,对大模型 AI 的理解超过 95% 的人,可以在相关讨论时发表高级、不跟风、又接地气的见解,别人只会和 AI 聊天,而你能调教 AI,并能用代码将大模型和业务衔接。
- 大模型 AI 能干什么?
- 大模型是怎样获得「智能」的?
- 用好 AI 的核心心法
- 大模型应用业务架构
- 大模型应用技术架构
- 代码示例:向 GPT-3.5 灌入新知识
- 提示工程的意义和核心思想
- Prompt 典型构成
- 指令调优方法论
- 思维链和思维树
- Prompt 攻击和防范
- …
第二阶段(30天):高阶应用
该阶段我们正式进入大模型 AI 进阶实战学习,学会构造私有知识库,扩展 AI 的能力。快速开发一个完整的基于 agent 对话机器人。掌握功能最强的大模型开发框架,抓住最新的技术进展,适合 Python 和 JavaScript 程序员。
- 为什么要做 RAG
- 搭建一个简单的 ChatPDF
- 检索的基础概念
- 什么是向量表示(Embeddings)
- 向量数据库与向量检索
- 基于向量检索的 RAG
- 搭建 RAG 系统的扩展知识
- 混合检索与 RAG-Fusion 简介
- 向量模型本地部署
- …
第三阶段(30天):模型训练
恭喜你,如果学到这里,你基本可以找到一份大模型 AI相关的工作,自己也能训练 GPT 了!通过微调,训练自己的垂直大模型,能独立训练开源多模态大模型,掌握更多技术方案。
到此为止,大概2个月的时间。你已经成为了一名“AI小子”。那么你还想往下探索吗?
- 为什么要做 RAG
- 什么是模型
- 什么是模型训练
- 求解器 & 损失函数简介
- 小实验2:手写一个简单的神经网络并训练它
- 什么是训练/预训练/微调/轻量化微调
- Transformer结构简介
- 轻量化微调
- 实验数据集的构建
- …
第四阶段(20天):商业闭环
对全球大模型从性能、吞吐量、成本等方面有一定的认知,可以在云端和本地等多种环境下部署大模型,找到适合自己的项目/创业方向,做一名被 AI 武装的产品经理。
- 硬件选型
- 带你了解全球大模型
- 使用国产大模型服务
- 搭建 OpenAI 代理
- 热身:基于阿里云 PAI 部署 Stable Diffusion
- 在本地计算机运行大模型
- 大模型的私有化部署
- 基于 vLLM 部署大模型
- 案例:如何优雅地在阿里云私有部署开源大模型
- 部署一套开源 LLM 项目
- 内容安全
- 互联网信息服务算法备案
- …
学习是一个过程,只要学习就会有挑战。天道酬勤,你越努力,就会成为越优秀的自己。
如果你能在15天内完成所有的任务,那你堪称天才。然而,如果你能完成 60-70% 的内容,你就已经开始具备成为一名大模型 AI 的正确特征了。