3. langgraph中的Tool Calling (How to handle large numbers of tools)

1. 工具定义

import os
import re
import unicodedata
import uuid

from langchain_core.tools import StructuredTool

def create_tool(company: str) -> StructuredTool:
    """Create a placeholder revenue-lookup tool for ``company``.

    The tool name is derived from the company name and contains only ASCII
    letters, digits and underscores, e.g. ``"A.O. Smith"`` -> ``"AO_Smith"``
    and ``"Yum! Brands"`` -> ``"Yum_Brands"``. To get there:

    * accented letters are folded to their base letter;
    * remaining non-ASCII characters and punctuation are dropped;
    * every whitespace character becomes ``_``.

    Args:
        company: Human-readable company name. It is used verbatim in the tool
            description and in the tool's output.

    Returns:
        A ``StructuredTool`` taking a single ``year: int`` argument. It returns
        a static revenue sentence for that company and year.
    """
    # Function-calling tool names are restricted to ASCII letters, digits,
    # "_" and "-". Fold accents (e.g. "é" -> "e") and drop whatever is still
    # not ASCII.
    ascii_company = (
        unicodedata.normalize("NFKD", company)
        .encode("ascii", "ignore")
        .decode("ascii")
    )
    # Drop punctuation, then map each whitespace character to "_".
    formatted_company = re.sub(r"\s", "_", re.sub(r"[^\w\s]", "", ascii_company))

    def company_tool(year: int) -> str:
        # Placeholder: static revenue information for the company and year.
        return f"{company} had revenues of $100 in {year}."

    return StructuredTool.from_function(
        company_tool,
        name=formatted_company,
        description=f"Information about {company}",
    )

# A short sample of S&P 500 company names; one tool is generated per name.
s_and_p_500_companies = [
    "3M",
    "A.O. Smith",
    "Abbott",
    "Accenture",
    "Advanced Micro Devices",
    "Yum! Brands",
    "Zebra Technologies",
    "Zimmer Biomet",
    "Zoetis",
]

# Registry: random UUID string -> tool. These keys are reused later as the
# vector-store document IDs, so a search hit maps straight back to its tool.
tool_registry = {
    str(uuid.uuid4()): generated_tool
    for generated_tool in map(create_tool, s_and_p_500_companies)
}
# Notebook cell: display every registered tool.
list(tool_registry.values())
[StructuredTool(name='3M', description='Information about 3M', args_schema=<class 'langchain_core.utils.pydantic.3M'>, func=<function create_tool.<locals>.company_tool at 0x000001CE50CE19E0>),
 StructuredTool(name='AO_Smith', description='Information about A.O. Smith', args_schema=<class 'langchain_core.utils.pydantic.AO_Smith'>, func=<function create_tool.<locals>.company_tool at 0x000001CE532894E0>),
 StructuredTool(name='Abbott', description='Information about Abbott', args_schema=<class 'langchain_core.utils.pydantic.Abbott'>, func=<function create_tool.<locals>.company_tool at 0x000001CE52ED62A0>),
 StructuredTool(name='Accenture', description='Information about Accenture', args_schema=<class 'langchain_core.utils.pydantic.Accenture'>, func=<function create_tool.<locals>.company_tool at 0x000001CE5328ACA0>),
 StructuredTool(name='Advanced_Micro_Devices', description='Information about Advanced Micro Devices', args_schema=<class 'langchain_core.utils.pydantic.Advanced_Micro_Devices'>, func=<function create_tool.<locals>.company_tool at 0x000001CE53289580>),
 StructuredTool(name='Yum_Brands', description='Information about Yum! Brands', args_schema=<class 'langchain_core.utils.pydantic.Yum_Brands'>, func=<function create_tool.<locals>.company_tool at 0x000001CE5328B880>),
 StructuredTool(name='Zebra_Technologies', description='Information about Zebra Technologies', args_schema=<class 'langchain_core.utils.pydantic.Zebra_Technologies'>, func=<function create_tool.<locals>.company_tool at 0x000001CE532C53A0>),
 StructuredTool(name='Zimmer_Biomet', description='Information about Zimmer Biomet', args_schema=<class 'langchain_core.utils.pydantic.Zimmer_Biomet'>, func=<function create_tool.<locals>.company_tool at 0x000001CE52EF0680>),
 StructuredTool(name='Zoetis', description='Information about Zoetis', args_schema=<class 'langchain_core.utils.pydantic.Zoetis'>, func=<function create_tool.<locals>.company_tool at 0x000001CE53289EE0>)]

2. 将工具存入向量数据库中

from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings

# One Document per tool. The description is the text that gets embedded, and
# the registry key becomes the document ID, so a search hit maps back to
# tool_registry. `tool_id` avoids shadowing the builtin `id`.
tool_documents = [
    Document(
        page_content=tool.description,
        id=tool_id,
        metadata={"tool_name": tool.name},
    )
    for tool_id, tool in tool_registry.items()
]

from langchain_community.embeddings import ZhipuAIEmbeddings

# Read the API key from the environment rather than hard-coding it in source.
embed = ZhipuAIEmbeddings(
    model="Embedding-3",
    api_key=os.environ["ZHIPUAI_API_KEY"],
)
vector_store = InMemoryVectorStore(embedding=embed)
document_ids = vector_store.add_documents(tool_documents)

3. 只在开始时选择一次工具的 graph（工具调用后不重新选择）

from typing import Annotated

from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition


# Graph state shared by every node.
# - messages: the conversation. The add_messages reducer merges the messages a
#   node returns into this list instead of replacing it.
# - selected_tools: IDs (keys of tool_registry, which are also the vector-store
#   document IDs) of the tools currently exposed to the agent. It has no
#   reducer, so each update replaces the previous list.
class State(TypedDict):
    messages: Annotated[list, add_messages]
    selected_tools: list[str]


builder = StateGraph(State)

# Every registered tool. These are handed to the ToolNode so it can execute
# whichever tool the agent calls.
tools = list(tool_registry.values())

# GLM model via Zhipu's OpenAI-compatible endpoint. The API key is read from
# the environment instead of being written into the source.
llm = ChatOpenAI(
    temperature=0,
    model="GLM-4-plus",
    openai_api_key=os.environ["ZHIPUAI_API_KEY"],
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)


def agent(state: State):
    """Run the LLM with only the tools chosen for this turn.

    Each ID in ``state["selected_tools"]`` is looked up in ``tool_registry``.
    The resulting tools are bound to the module-level ``llm``, which is then
    invoked on the conversation so far.

    Returns:
        A state update whose ``messages`` list holds the model's reply.
    """
    chosen_tools = [tool_registry[tool_id] for tool_id in state["selected_tools"]]
    reply = llm.bind_tools(chosen_tools).invoke(state["messages"])
    return {"messages": [reply]}


def select_tools(state: State):
    """Choose tools by semantic search over their descriptions.

    The newest message's content is used as the query against
    ``vector_store``. In this graph that is the user's question, because the
    node is only reached from START. The IDs of the matching documents become
    the new ``selected_tools``.
    """
    search_text = state["messages"][-1].content
    hits = vector_store.similarity_search(search_text)
    return {"selected_tools": [hit.id for hit in hits]}


# Flow: START -> select_tools -> agent.
# tools_condition sends the agent's reply to "tools" when it contains tool
# calls, and to the end otherwise. Tool results go straight back to the agent,
# so tools are selected only once per run.
tool_node = ToolNode(tools=tools)

builder.add_node("agent", agent)
builder.add_node("select_tools", select_tools)
builder.add_node("tools", tool_node)

builder.add_edge(START, "select_tools")
builder.add_edge("select_tools", "agent")
builder.add_conditional_edges("agent", tools_condition, path_map=["tools", "__end__"])
builder.add_edge("tools", "agent")

graph = builder.compile()

graph 可视化

from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG inside the notebook.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    # Rendering needs extra dependencies and is optional. Say why it was
    # skipped instead of failing silently.
    print(f"Graph rendering skipped: {exc!r}")

（图：START → select_tools → agent；agent 根据 tools_condition 决定调用 tools（执行后回到 agent）或结束）

测试

# Single-turn check: ask about AMD and inspect which tool IDs were selected.
user_input = "Can you give me some information about AMD in 2022?"

# Run the whole graph. The final state holds the conversation and the
# selected tool IDs.
result = graph.invoke({"messages": [("user", user_input)]})
print(result["selected_tools"])
['cde4f2b3-52d4-4b01-8a27-3894126cce96', '64fede49-2c54-439d-991c-42b0bcda2897', 'e912c52f-7661-4611-bfeb-f28cae7cac93', '1887be17-47b6-424a-8236-a031441fb4c5']
# Pretty-print the full conversation from the final graph state.
for message in result["messages"]:
    message.pretty_print()
================================ Human Message =================================

Can you give me some information about AMD in 2022?
================================== Ai Message ==================================
Tool Calls:
  Advanced_Micro_Devices (call_-9187874028535886692)
 Call ID: call_-9187874028535886692
  Args:
    year: 2022
================================= Tool Message =================================
Name: Advanced_Micro_Devices

Advanced Micro Devices had revenues of $100 in 2022.
================================== Ai Message ==================================

In 2022, Advanced Micro Devices (AMD) reported revenues of $100. If you need more detailed information or specific aspects about AMD's performance or activities during that year, feel free to ask!

4. 每次工具调用后重新选择工具的 graph

from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langgraph.pregel.retry import RetryPolicy

from pydantic import BaseModel, Field


# This rebinds the module-level ``llm``. Both the ``select_tools`` node below
# and the ``agent`` node defined earlier read this global at call time, so from
# here on they use GLM-4-flash. The API key comes from the environment.
llm = ChatOpenAI(
    temperature=0,
    model="GLM-4-flash",
    openai_api_key=os.environ["ZHIPUAI_API_KEY"],
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)

# Schema bound as a tool in select_tools so the LLM returns a single `query`
# argument describing what extra information it needs. The class docstring and
# the field description are sent to the model as part of the tool schema, so
# changing them changes the prompt.
class QueryForTools(BaseModel):
    """Generate a query for additional tools."""

    query: str = Field(..., description="Query for additional tools.")

def _generate_tool_query(messages: list) -> str:
    """Ask ``llm`` for a short search query describing the information still needed.

    The model is asked to call ``QueryForTools``. Its ``query`` argument is
    returned with surrounding whitespace stripped. An empty string means the
    model asked for nothing further, or produced no tool call at all.
    """
    system = SystemMessage(
        "Given this conversation, generate a query for additional tools. "
        "The query should be a short string containing what type of information "
        "is needed. If no further information is needed, "
        "return an empty string for the query."
    )
    response = llm.bind_tools([QueryForTools], tool_choice=True).invoke(
        [system] + messages
    )
    # The provider may not honour the forced tool choice. Treat a missing call
    # as "no query" instead of failing with IndexError (which the RetryPolicy
    # would not retry anyway).
    if not response.tool_calls:
        return ""
    return str(response.tool_calls[0]["args"].get("query", "")).strip()


def select_tools(state: State):
    """Select tools for the agent's next step via vector search.

    The search query depends on the last message:

    * ``HumanMessage``: the message content is used directly. For this demo
      the first selection is deliberately wrong, because the
      ``Advanced_Micro_Devices`` tool is removed from the results.
    * ``ToolMessage``: ``llm`` generates a query from the whole conversation.

    If the query is empty, no search is run and the previously selected tools
    are kept.

    Returns:
        ``{"selected_tools": [...]}``, a list of tool IDs (keys of ``tool_registry``).

    Raises:
        TypeError: if the last message is neither a HumanMessage nor a ToolMessage.
    """
    last_message = state["messages"][-1]
    hack_remove_tool_condition = False  # Simulate an error in the first tool selection

    if isinstance(last_message, HumanMessage):
        query = last_message.content
        hack_remove_tool_condition = True  # Simulate wrong tool selection
    elif isinstance(last_message, ToolMessage):
        query = _generate_tool_query(state["messages"])
    else:
        # An explicit error rather than `assert`, which is stripped under -O.
        raise TypeError(
            "select_tools expects the last message to be a HumanMessage or "
            f"ToolMessage, got {type(last_message).__name__}"
        )

    if not query:
        # Nothing new to look up: keep the tools the agent already has instead
        # of embedding an empty query.
        selected_tools = list(state.get("selected_tools", []))
    else:
        # Search the tool vector store using the query.
        tool_documents = vector_store.similarity_search(query)
        selected_tools = [
            document.id
            for document in tool_documents
            # Demo hack: drop the correct tool from the first selection.
            if not (
                hack_remove_tool_condition
                and document.metadata["tool_name"] == "Advanced_Micro_Devices"
            )
        ]
    print(f"selected_tools:{selected_tools}")
    return {"selected_tools": selected_tools}


# Graph with repeated tool selection: START -> select_tools -> agent.
# After every tool call, the results flow back through select_tools before the
# agent runs again. select_tools carries a RetryPolicy with max_attempts=3.
graph_builder = StateGraph(State)
tool_node = ToolNode(tools=tools)

graph_builder.add_node("agent", agent)
graph_builder.add_node(
    "select_tools",
    select_tools,
    retry=RetryPolicy(max_attempts=3),
)
graph_builder.add_node("tools", tool_node)

graph_builder.add_edge(START, "select_tools")
graph_builder.add_edge("select_tools", "agent")
graph_builder.add_conditional_edges("agent", tools_condition)
graph_builder.add_edge("tools", "select_tools")

graph = graph_builder.compile()

graph可视化

from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG inside the notebook.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    # Rendering needs extra dependencies and is optional. Say why it was
    # skipped instead of failing silently.
    print(f"Graph rendering skipped: {exc!r}")

（图：START → select_tools → agent；agent 根据 tools_condition 决定调用 tools 或结束；tools 执行完后回到 select_tools 重新选择工具）

示例

# Run the graph with repeated tool selection. select_tools prints its
# selection each time it runs.
user_input = "Can you give me some information about AMD in 2022? use the tool to check."

result = graph.invoke({"messages": [("user", user_input)]})
selected_tools:['e912c52f-7661-4611-bfeb-f28cae7cac93', '64fede49-2c54-439d-991c-42b0bcda2897', '05c2a6eb-0d1a-45a0-8f2f-af581ec738c4']
selected_tools:['cde4f2b3-52d4-4b01-8a27-3894126cce96', '64fede49-2c54-439d-991c-42b0bcda2897', '05c2a6eb-0d1a-45a0-8f2f-af581ec738c4', 'e912c52f-7661-4611-bfeb-f28cae7cac93']
selected_tools:['cde4f2b3-52d4-4b01-8a27-3894126cce96', '1887be17-47b6-424a-8236-a031441fb4c5', '05c2a6eb-0d1a-45a0-8f2f-af581ec738c4', '64fede49-2c54-439d-991c-42b0bcda2897']
# Pretty-print the full conversation from the final graph state.
for message in result["messages"]:
    message.pretty_print()
================================ Human Message =================================

Can you give me some information about AMD in 2022? use the tool to check.
================================== Ai Message ==================================
Tool Calls:
  3M (call_-9187867294025044519)
 Call ID: call_-9187867294025044519
  Args:
    year: 2022
================================= Tool Message =================================
Name: 3M

3M had revenues of $100 in 2022.
================================== Ai Message ==================================
Tool Calls:
  Advanced_Micro_Devices (call_-9187869871006560462)
 Call ID: call_-9187869871006560462
  Args:
    year: 2022
=================================[1m Tool Message [0m=================================
Name: Advanced_Micro_Devices

Advanced Micro Devices had revenues of $100 in 2022.
================================== Ai Message ==================================

In 2022, Advanced Micro Devices (AMD) had revenues of $100.

这里使用的模型是 `GLM-4-flash`；如果换成 `GLM-4-plus`，运行会出错。注意：本节重新定义了全局变量 `llm`，而 `agent` 节点在调用时读取的也是这个全局 `llm`，因此 agent 和 select_tools 都会使用这里指定的模型。

参考链接：https://langchain-ai.github.io/langgraph/how-tos/many-tools/#incorporating-with-an-agent

如果有任何问题,欢迎在评论区提问。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值