ChatGLM3 tool_registry.py 代码解析

ChatGLM3 tool_registry.py 代码解析

0. 背景

学习 ChatGLM3 的项目内容,过程中使用 AI 代码工具,对代码进行解释,帮助自己快速理解代码。这篇文章记录 ChatGLM3 tool_registry.py 的代码解析内容。

1. tool_registry.py

from copy import deepcopy
import inspect
from pprint import pformat
import traceback
from types import GenericAlias
from typing import get_origin, Annotated

_TOOL_HOOKS = {}
_TOOL_DESCRIPTIONS = {}

这段代码定义了几个全局变量和导入了一些模块。让我来逐个解释:

  • from copy import deepcopy:从 copy 模块导入 deepcopy 函数,用于深拷贝对象。

  • import inspect:导入 inspect 模块,用于获取对象的信息。

  • from pprint import pformat:从 pprint 模块导入 pformat 函数,用于格式化打印对象。

  • import traceback:导入 traceback 模块,用于打印异常堆栈信息。

  • from types import GenericAlias:从 types 模块导入 GenericAlias 类,用于表示泛型类型。

  • from typing import get_origin, Annotated:从 typing 模块导入 get_origin 和 Annotated 函数,用于获取泛型类型的原始类型和注解信息。

  • _TOOL_HOOKS = {}:定义一个空的全局字典变量 _TOOL_HOOKS,用于存储工具的钩子函数。

  • _TOOL_DESCRIPTIONS = {}:定义一个空的全局字典变量 _TOOL_DESCRIPTIONS,用于存储工具的描述信息。

这段代码的作用可能是为后续的工具注册和存储钩子函数以及描述信息提供了一个全局的数据结构。

def register_tool(func: callable):
    tool_name = func.__name__
    tool_description = inspect.getdoc(func).strip()
    python_params = inspect.signature(func).parameters
    tool_params = []
    for name, param in python_params.items():
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter `{name}` missing type annotation")
        if get_origin(annotation) != Annotated:
            raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
        
        typ, (description, required) = annotation.__origin__, annotation.__metadata__
        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
        if not isinstance(description, str):
            raise TypeError(f"Description for `{name}` must be a string")
        if not isinstance(required, bool):
            raise TypeError(f"Required for `{name}` must be a bool")

        tool_params.append({
            "name": name,
            "description": description,
            "type": typ,
            "required": required
        })
    tool_def = {
        "name": tool_name,
        "description": tool_description,
        "params": tool_params
    }

    print("[registered tool] " + pformat(tool_def))
    _TOOL_HOOKS[tool_name] = func
    _TOOL_DESCRIPTIONS[tool_name] = tool_def

    return func

这段代码定义了一个名为 register_tool 的函数,该函数接受一个可调用对象 func 作为参数。

以下是代码的详细解析:

  • tool_name = func.name:获取传入函数 func 的名称,并将其赋值给变量 tool_name。
  • tool_description = inspect.getdoc(func).strip():使用 inspect.getdoc 函数获取传入函数 func 的文档字符串,并去除首尾的空白字符,将结果赋值给变量 tool_description。
  • python_params = inspect.signature(func).parameters:使用 - inspect.signature 函数获取传入函数 func 的参数签名,并将其参数信息保存在变量 python_params 中。
  • tool_params = []:创建一个空列表 tool_params,用于存储工具的参数信息。
  • for name, param in python_params.items()::遍历 python_params 中的每个参数项,其中 name 是参数名,param 是参数对象。
    • annotation = param.annotation:获取参数对象的注解,并将其赋值给变量 annotation。
    • if annotation is inspect.Parameter.empty::如果注解为空,则表示参数缺少类型注解,抛出 TypeError 异常。
    • if get_origin(annotation) != Annotated::如果注解类型不是 typing.Annotated,抛出 TypeError 异常。
    • typ, (description, required) = annotation.origin, annotation.metadata:从注解中获取类型和元数据信息,并将其分别赋值给 typ、description 和 required 变量。
    • typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.name:将类型转换为字符串,如果类型是 GenericAlias(泛型类型),则将其转换为字符串形式,否则保留类型的名称。
    • if not isinstance(description, str)::如果描述不是字符串类型,抛出 TypeError 异常。
    • if not isinstance(required, bool)::如果 required 不是布尔类型,抛出 TypeError 异常。
    • tool_params.append({…}):将参数的名称、描述、类型和是否必需组成的字典添加到 tool_params 列表中。
  • tool_def = {…}:创建一个字典 tool_def,包含工具的名称、描述和参数信息。
  • print("[registered tool] " + pformat(tool_def)):打印注册的工具的定义,使用 pformat 函数格式化输出。
  • _TOOL_HOOKS[tool_name] = func:将函数 func 添加到全局字典变量 _TOOL_HOOKS 中,键为工具的名称。
  • _TOOL_DESCRIPTIONS[tool_name] = tool_def:将工具的定义添加到全局字典变量 _TOOL_DESCRIPTIONS 中,键为工具的名称。
  • return func:返回原始的函数 func。
    这段代码的作用是将传入的函数作为工具进行注册,并将工具的名称、描述和参数信息存储在全局字典变量中。注册的工具可以通过 _TOOL_HOOKS 全局字典变量进行调用。
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
    if tool_name not in _TOOL_HOOKS:
        return f"Tool `{tool_name}` not found. Please use a provided tool."
    tool_call = _TOOL_HOOKS[tool_name]
    try:
        ret = tool_call(**tool_params)  
    except:
        ret = traceback.format_exc()
    return str(ret)

def get_tools() -> dict:
    return deepcopy(_TOOL_DESCRIPTIONS)

这段代码定义了两个函数:dispatch_tool 和 get_tools。让我为你逐个解释:

函数 dispatch_tool(tool_name: str, tool_params: dict) -> str:

该函数接受两个参数 tool_name 和 tool_params,并返回一个字符串。

  • if tool_name not in _TOOL_HOOKS::检查传入的工具名称 tool_name 是否存在于全局字典变量 _TOOL_HOOKS 中。
    如果工具名称不存在,返回一个提示字符串,表示找不到该工具。
  • tool_call = _TOOL_HOOKS[tool_name]:从全局字典变量 _TOOL_HOOKS 中获取与工具名称对应的工具函数,并将其赋值给变量 tool_call。
  • try::尝试执行工具函数,并捕获可能的异常。
    • ret = tool_call(**tool_params):使用传入的参数 tool_params 调用工具函数,并将返回值赋值给变量 ret。这里使用 ** 运算符将 tool_params 字典解包为关键字参数传递给工具函数。
  • except::捕获可能的异常。
    • ret = traceback.format_exc():如果出现异常,将异常的堆栈信息格式化为字符串,并将其赋值给变量 ret。
  • return str(ret):返回结果,无论是工具函数的返回值还是异常堆栈信息,都将转换为字符串并返回。
    该函数的作用是根据传入的工具名称和参数调用对应的工具函数,并返回结果或异常信息的字符串表示。

函数 get_tools() -> dict:

该函数不接受任何参数,返回一个字典。

  • return deepcopy(_TOOL_DESCRIPTIONS):返回全局字典变量 _TOOL_DESCRIPTIONS 的深拷贝。
    该函数的作用是返回全局字典变量 _TOOL_DESCRIPTIONS 的副本,以提供工具的名称、描述和参数信息。

这两个函数一起提供了工具的调度和获取工具信息的功能。dispatch_tool 函数用于调用具体的工具函数,而 get_tools 函数用于获取所有已注册工具的描述信息。

deepcopy: deepcopy 是一个函数,用于创建一个对象的深拷贝。深拷贝是指创建一个新对象,将原始对象的所有元素递归地复制到新对象中,包括嵌套的对象。换句话说,它会创建一个原始对象的完全独立副本,而不仅仅是引用原始对象的内存地址。
深拷贝对于需要完全独立的副本的情况非常有用,尤其是在处理可变对象时。通过深拷贝,可以确保修改一个对象的副本不会影响到原始对象,因为它们是相互独立的。
例如,假设有一个包含嵌套列表和字典的对象 obj,如果直接对 obj 进行赋值操作,那么新对象将只是原始对象的引用,而不是副本。这意味着对新对象的修改也会反映到原始对象中。但是,如果使用 deepcopy 函数创建一个新对象 new_obj,那么 new_obj 将是 obj 的深拷贝副本,对 new_obj 的修改不会影响到 obj。

@register_tool
def random_number_generator(
    seed: Annotated[int, 'The random seed used by the generator', True], 
    range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
) -> int:
    """
    Generates a random number x, s.t. range[0] <= x < range[1]
    """
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer")
    if not isinstance(range, tuple):
        raise TypeError("Range must be a tuple")
    if not isinstance(range[0], int) or not isinstance(range[1], int):
        raise TypeError("Range must be a tuple of integers")

    import random
    return random.Random(seed).randint(*range)

这段代码定义了一个名为 random_number_generator 的函数,并使用 @register_tool 装饰器将其注册为一个工具。

函数接受两个参数 seed 和 range,并返回一个整数。下面是对代码的详细解释:

  • @register_tool:@ 符号是装饰器语法,用于在函数定义之前修饰函数。@register_tool 表示将该函数注册为一个工具。具体工具注册的逻辑在你提供的代码中没有呈现,可以在其他地方找到。

  • def random_number_generator(seed: Annotated[int, ‘The random seed used by the generator’, True], range: Annotated[tuple[int, int], ‘The range of the generated numbers’, True]) -> int::这是函数的定义部分。函数名为 random_number_generator,接受两个参数 seed 和 range,并指定返回类型为整数。

  • “”" Generates a random number x, s.t. range[0] <= x < range[1] “”":这是函数的文档字符串(docstring),用于描述函数的功能。根据文档字符串的描述,该函数生成一个介于 range[0] 和 range[1] 之间的随机整数 x。

  • 参数验证部分:在函数体内部,对传入的参数进行验证,确保它们具有正确的类型和值。

    • if not isinstance(seed, int)::检查 seed 是否为整数类型,如果不是,则抛出 TypeError 异常,提示 “Seed must be an integer”。
    • if not isinstance(range, tuple)::检查 range 是否为元组类型,如果不是,则抛出 TypeError 异常,提示 “Range must be a tuple”。
    • if not isinstance(range[0], int) or not isinstance(range[1], int)::检查 range 的元素是否为整数类型,如果不是,则抛出 TypeError 异常,提示 “Range must be a tuple of integers”。
  • import random:导入 Python 标准库中的 random 模块,用于生成随机数。

  • return random.Random(seed).randint(*range):使用 random 模块生成一个随机整数,并将其作为函数的返回值。random.Random(seed) 创建了一个具有指定种子 seed 的随机数生成器对象,然后使用 randint(*range) 方法生成介于 range[0] 和 range[1] 之间的随机整数。

总之,这段代码定义了一个将参数验证和随机数生成结合在一起的函数。它使用装饰器将函数注册为一个工具,并在调用时生成指定范围内的随机整数。

@register_tool
def get_weather(
    city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the current weather for `city_name`
    """

    if not isinstance(city_name, str):
        raise TypeError("City name must be a string")

    key_selection = {
        "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc",  "observation_time"],
    }
    import requests
    try:
        resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
        resp.raise_for_status()
        resp = resp.json()
        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
    except:
        import traceback
        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() 

    return str(ret)

这段代码定义了一个名为 get_weather 的函数,并使用 @register_tool 装饰器将其注册为一个工具。

函数接受一个参数 city_name,并返回一个字符串。下面是对代码的详细解释:

  • @register_tool:@ 符号是装饰器语法,用于在函数定义之前修饰函数。@register_tool 表示将该函数注册为一个工具。具体工具注册的逻辑在你提供的代码中没有呈现,可以在其他地方找到。

  • def get_weather(city_name: Annotated[str, ‘The name of the city to be queried’, True]) -> str::这是函数的定义部分。函数名为 get_weather,接受一个 city_name 参数,指定返回类型为字符串。

  • “”" Get the current weather for city_name “”":这是函数的文档字符串(docstring),用于描述函数的功能。根据文档字符串的描述,该函数用于获取指定城市的当前天气情况。

  • 参数验证部分:在函数体内部,对传入的参数进行验证,确保它们具有正确的类型和值。

    • if not isinstance(city_name, str)::检查 city_name 是否为字符串类型,如果不是,则抛出 TypeError 异常,提示 “City name must be a string”。
  • key_selection = {…}:定义了一个字典变量 key_selection,用于存储需要从 API 响应中提取的天气信息的键值选择。该字典的键代表不同的天气信息,而对应的值是一个列表,包含了该天气信息所对应的子键。

  • import requests:导入 Python 第三方库 requests,用于发送 HTTP 请求。

  • try::尝试执行一段代码,并捕获可能的异常。

    • resp = requests.get(f"https://wttr.in/{city_name}?format=j1"):使用 requests 发送一个 GET 请求,获取指定城市的天气数据。URL 中的 {city_name} 部分会被替换为实际的城市名称。
    • resp.raise_for_status():检查请求的状态码,如果是错误的状态码,将抛出一个异常。
    • resp = resp.json():将响应的 JSON 数据解析为 Python 字典,并将其赋值给 resp 变量。
    • ret = {…}:根据 key_selection 字典中的键值选择,从响应中提取相应的天气信息,存储在 ret 变量中。这里使用了字典推导式来生成结果。
  • except::捕获可能的异常。

    • import traceback:导入 Python 标准库中的 traceback 模块,用于获取异常的堆栈信息。
    • ret = “Error encountered while fetching weather data!\n” + traceback.format_exc():如果发生异常,将错误提示信息和堆栈信息拼接成一个字符串,并将其赋值给 ret 变量。
  • return str(ret):返回结果,将结果转换为字符串类型后返回。

总之,这段代码定义了一个用于获取指定城市天气的函数。它使用 requests 库发送 HTTP 请求获取天气数据,并从响应中提取指定的天气信息。如果发生任何异常,它会将错误提示信息和堆栈信息返回。

请注意,这段代码中的 @register_tool 装饰器和 requests 库是额外的依赖项,你可能需要在其他地方找到这些实现或库的定义。

完结!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值