Introduction
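GPT-3.5 is a powerful language model that performs remarkably well at text generation, translation, and question answering. At mathematical calculation, however, it often struggles and frequently produces wrong results (see the figure below). Why?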
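To answer this question, we need to start from GPT-3.5's model architecture. GPT-3.5 is an autoregressive language model. The core idea of such a model is to use conditional probabilities to predict the next element of a text sequence. Put simply, given the sequence of words or characters so far (more precisely, tokens), it predicts the one most likely to come next. This context-based prediction is what makes GPT-3.5 so good at understanding and generating natural language.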
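In symbols, the standard autoregressive factorization writes the probability of a token sequence x_1, ..., x_T as a product of next-token conditionals. Generation then picks each token from the conditional given everything before it. This is the general textbook formulation, not a detail published specifically for GPT-3.5:

```latex
p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1})
```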
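Mathematical calculation is a different matter. It has to follow strict logical rules and algorithmic steps, such as operator precedence and carrying digits, and these cannot always be inferred from the contextual information in a text sequence. The model emits the digits of an answer one token at a time. Each digit is chosen because it is probable given the preceding text, not because it was computed. A single wrong digit therefore makes the whole result wrong, and nothing in the generation process checks it.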
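To solve this, we can use Function Calling to let an external tool do the math. With Function Calling, GPT-3.5 turns a natural-language request into a structured function call: a function name plus JSON arguments. Our program executes that call, for example by evaluating an expression, and sends the result back to the model. This yields far more accurate results; the effect is shown in the figure below. To make the principles behind an Agent easier to understand, this article does not use any Agent framework (such as LangChain or LangGraph).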
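The round trip can be sketched without calling any API. The model's output is only a description of a call: a function name and JSON arguments. The application parses it, runs the function, and builds a `tool` message to send back.

This sketch has the following assumptions:
- The call id `call_example` is made up.
- `toy_calculator` is a hypothetical stand-in that only handles `a op b`.

```python
import json

# A hypothetical tool call in the shape used by the Chat Completions API:
# the model only names the function and supplies JSON-encoded arguments.
model_tool_call = {
    "id": "call_example",  # made-up id for illustration
    "type": "function",
    "function": {"name": "calculator", "arguments": '{"expression": "123 * 456"}'},
}


def toy_calculator(expression: str) -> int:
    """Hypothetical stand-in that only understands 'a op b' with + - *."""
    a, op, b = expression.split()
    ops = {"+": lambda x, y: x + y, "-": lambda x, y: x - y, "*": lambda x, y: x * y}
    return ops[op](int(a), int(b))


def run_tool_call(tool_call: dict, registry: dict) -> dict:
    """Execute the requested function locally and build the 'tool' reply message."""
    name = tool_call["function"]["name"]
    args = json.loads(tool_call["function"]["arguments"])  # arguments arrive as a JSON string
    result = registry[name](**args)
    return {"role": "tool", "tool_call_id": tool_call["id"], "content": str(result)}


tool_message = run_tool_call(model_tool_call, {"calculator": toy_calculator})
print(tool_message)  # {'role': 'tool', 'tool_call_id': 'call_example', 'content': '56088'}
```

The `tool_call_id` ties the result to the request. This is how the model knows which call each result answers.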
Implementation steps
Install dependencies
pip install openai requests
Import the required packages
import json
import os
import requests
import urllib.parse
from typing import List, Iterable, Union
from openai import OpenAI
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from openai.types.chat.chat_completion_user_message_param import ChatCompletionUserMessageParam
from openai.types.chat.chat_completion_tool_message_param import ChatCompletionToolMessageParam
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
Create the OpenAI client
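# Read the credentials from the environment. If OPENAI_API_BASE is unset, base_url is None
# and the SDK falls back to its default endpoint.
api_key = os.getenv('OPENAI_API_KEY')
base_url = os.getenv("OPENAI_API_BASE")
model = "gpt-3.5-turbo"
stream = True
client = OpenAI(
    api_key=api_key,
    base_url=base_url
)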
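Before testing, set the environment variable OPENAI_API_KEY (and, optionally, OPENAI_API_BASE to use a different endpoint). If you don't have an OpenAI API key, you can apply for Alibaba Cloud's Tongyi Qianwen (Qwen) at https://dashscope.console.aliyun.com and make the following changes.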
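Note: at the time of writing, Qwen's Function Calling does not support streaming output. You therefore need to set stream=False, which gives a somewhat worse user experience than streaming.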
api_key = "Replace with your API Key"
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
model = "qwen-plus"
stream = False
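Instead of editing the variables by hand, both setups can be selected from the environment. The following is a minimal sketch under these assumptions:
- The environment variable name `DASHSCOPE_API_KEY` is an assumption, as is the rule that its presence switches to Qwen.
- `ProviderConfig` and `load_config` are hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class ProviderConfig:
    api_key: Optional[str]
    base_url: Optional[str]
    model: str
    stream: bool


def load_config() -> ProviderConfig:
    """Hypothetical helper: use Qwen if DASHSCOPE_API_KEY is set, otherwise OpenAI."""
    dashscope_key = os.getenv("DASHSCOPE_API_KEY")  # assumed variable name
    if dashscope_key:
        return ProviderConfig(
            api_key=dashscope_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            model="qwen-plus",
            stream=False,  # Qwen Function Calling is used without streaming
        )
    return ProviderConfig(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),  # None -> SDK default endpoint
        model="gpt-3.5-turbo",
        stream=True,
    )
```

The client would then be created as `OpenAI(api_key=cfg.api_key, base_url=cfg.base_url)` with `cfg = load_config()`.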
Define the tools (math calculation, weather lookup, and web search)
tools = [
    {
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Evaluate an arithmetic expression written in Python syntax, "
                           "using numbers, parentheses and the operators + - * / // % **",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "the mathematical expression, e.g. (3 + 4) * 2 ** 10"
                    }
                },
                "required": ["expression"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    }
                },
                "required": ["location"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "search_web",
            "description": "Retrieve up-to-date information from the web",
            "parameters": {
                "type": "object",
                "properties": {
                    "keyword": {
                        "type": "string",
                        "description": "the keyword to search"
                    }
                },
                "required": ["keyword"]
            }
        }
    }
]
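All three definitions share the same shape: an object whose parameters are all required strings. A small builder can remove the repetition. `make_function_tool` is a hypothetical helper, and it is only valid for tools whose parameters are all required strings, as here.

```python
import json
from typing import Dict


def make_function_tool(name: str, description: str, params: Dict[str, str]) -> dict:
    """Hypothetical helper: build a 'function' tool whose parameters are all required strings."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {p: {"type": "string", "description": d} for p, d in params.items()},
                "required": list(params),
            },
        },
    }


tools = [
    make_function_tool(
        "calculator",
        "Evaluate an arithmetic expression written in Python syntax, "
        "using numbers, parentheses and the operators + - * / // % **",
        {"expression": "the mathematical expression, e.g. (3 + 4) * 2 ** 10"}),
    make_function_tool("get_current_weather", "Get the current weather in a given location",
                       {"location": "The city and state, e.g. San Francisco, CA"}),
    make_function_tool("search_web", "Retrieve up-to-date information from the web",
                       {"keyword": "the keyword to search"}),
]
print(json.dumps(tools[1], indent=2))
```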
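Implement the calculator tool
When the LLM calls the function through Function Calling, it turns the natural-language description into an expression. This implementation therefore simply evaluates that expression with eval. Keep in mind that eval executes arbitrary Python code produced by the model, which is acceptable for a demo but unsafe in production.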
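def calculator(expression: str) -> str:
    # eval runs arbitrary Python; empty globals without builtins block the most obvious
    # abuses (e.g. __import__), but this is not a real sandbox
    data = {'result': eval(expression, {"__builtins__": {}}, {})}
    return str(data)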
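A safer alternative parses the expression with Python's `ast` module and evaluates only whitelisted arithmetic nodes, so no arbitrary code runs. This is a sketch, not the article's code:
- `safe_eval` is a hypothetical name.
- The exponent cap of 1000 is an arbitrary choice.
- Errors are returned as text so the model can see them.

```python
import ast
import operator

# Only these operators are allowed; names, calls and attributes are rejected.
_BIN_OPS = {
    ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
    ast.Div: operator.truediv, ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod, ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def safe_eval(expression: str):
    """Evaluate numbers combined with + - * / // % ** and parentheses, without eval()."""
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            left, right = _eval(node.left), _eval(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > 1000:
                raise ValueError("exponent too large")  # avoid huge computations like 9**9**9
            return _BIN_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        raise ValueError(f"unsupported expression element: {type(node).__name__}")
    return _eval(ast.parse(expression, mode="eval"))


def calculator(expression: str) -> str:
    try:
        return str({'result': safe_eval(expression)})
    except (ValueError, SyntaxError, ZeroDivisionError, OverflowError) as exc:
        # Report the problem as text so the model can correct the expression
        return str({'error': str(exc)})
```

The drop-in `calculator` keeps the original return format, so `invoke_tool` does not need to change.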
Implement the weather lookup tool
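The weather implementation is not reproduced here. `invoke_tool` only needs a function `get_current_weather(location)` that returns a string. Below is a hypothetical placeholder with that signature. It calls no real weather API and should be replaced with a request to a service of your choice.

```python
import json


def get_current_weather(location: str) -> str:
    """Hypothetical stub: returns a fixed placeholder instead of real weather data."""
    # JSON keeps the tool result easy for the model to read
    return json.dumps({"location": location, "weather": "unavailable (stub)"}, ensure_ascii=False)
```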
Implement the web search tool
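This implementation is not shown either. The imported `urllib.parse` suggests the keyword is URL-encoded into a search request. Below is a minimal sketch under these assumptions:
- `SEARCH_ENDPOINT` is a hypothetical URL.
- The HTTP call is injected as `fetch` so the sketch runs offline. With `requests` it could be `lambda url: requests.get(url, timeout=10).text`.

```python
import urllib.parse
from typing import Callable, Optional

# Hypothetical endpoint for illustration only; substitute a real search API.
SEARCH_ENDPOINT = "https://search.example.com/api"


def build_search_url(keyword: str) -> str:
    # urlencode percent-encodes non-ASCII keywords and turns spaces into '+'
    return SEARCH_ENDPOINT + "?" + urllib.parse.urlencode({"q": keyword})


def search_web(keyword: str, fetch: Optional[Callable[[str], str]] = None) -> str:
    """Return search results as text; `fetch` performs the HTTP GET (injected for testability)."""
    url = build_search_url(keyword)
    if fetch is None:
        return f"search_web is not configured (would request {url})"
    return fetch(url)
```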
Define the function invoke_tool, which calls the matching function based on the function name in tool_call
def invoke_tool(tool_call: Union[ChatCompletionMessageToolCall, ChoiceDeltaToolCall]) -> ChatCompletionToolMessageParam:
    # The tool message must reference the id of the tool call it answers
    result = ChatCompletionToolMessageParam(
        role="tool",
        tool_call_id=tool_call.id)
    func_name = tool_call.function.name
    # The arguments arrive as a JSON string, e.g. '{"expression": "1 + 2"}'
    args = json.loads(tool_call.function.arguments)
    if func_name == "calculator":
        result["content"] = calculator(args['expression'])
    elif func_name == "get_current_weather":
        result["content"] = get_current_weather(args['location'])
    elif func_name == "search_web":
        result["content"] = search_web(args['keyword'])
    else:
        result["content"] = "Function not defined"
    return result
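As more tools are added, the `if/elif` chain can be replaced by a lookup table. The sketch below also reports malformed JSON arguments back to the model instead of raising. `invoke_tool_from_registry` is a hypothetical name. The test object built with `SimpleNamespace` only mimics the `.id` and `.function.name/.arguments` attributes of a real tool call.

```python
import json
from types import SimpleNamespace
from typing import Callable, Dict


def invoke_tool_from_registry(tool_call, registry: Dict[str, Callable[..., str]]) -> dict:
    """Table-driven variant of invoke_tool (illustrative sketch)."""
    message = {"role": "tool", "tool_call_id": tool_call.id}
    func = registry.get(tool_call.function.name)
    if func is None:
        message["content"] = f"Function {tool_call.function.name} is not defined"
        return message
    try:
        args = json.loads(tool_call.function.arguments or "{}")
        message["content"] = func(**args)
    except (json.JSONDecodeError, TypeError) as exc:
        # Malformed JSON or unexpected argument names: tell the model instead of crashing
        message["content"] = f"Invalid arguments: {exc}"
    return message


# Usage with a stand-in object shaped like a tool call
fake_call = SimpleNamespace(id="call_1", function=SimpleNamespace(name="echo", arguments='{"text": "hi"}'))
print(invoke_tool_from_registry(fake_call, {"echo": lambda text: text.upper()}))
# {'role': 'tool', 'tool_call_id': 'call_1', 'content': 'HI'}
```

With the functions from this article, the registry would be `{"calculator": calculator, "get_current_weather": get_current_weather, "search_web": search_web}`.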
Handle tool_calls in streaming output
When streaming, GPT sends the tool_calls information in fragments that have to be concatenated. An example stream:
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_WCUtMR24ssozFrdfG8dEdvOW","type":"function","function":{"name":"search_web","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"keyword"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"小"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"米"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"SU"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"7"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"价格"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9YRB1Tt2TqemuKwTELB7yItNZA2O3","object":"chat.completion.chunk","created":1717994647,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
data: [DONE]
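In this example only the first chunk carries the call `id` and the function `name`. Every later chunk (all with `index` 0) contributes a fragment of the `arguments` string. Joining the fragments in arrival order reproduces the complete JSON arguments:

```python
import json

# The "arguments" values of the chunks above, in arrival order
fragments = ['', '{"', 'keyword', '":"', '小', '米', 'SU', '7', '价格', '"}']
arguments = "".join(fragments)
print(arguments)              # {"keyword":"小米SU7价格"}
print(json.loads(arguments))  # {'keyword': '小米SU7价格'}
```

The individual fragments are not valid JSON on their own. The string can only be parsed after the stream ends with `finish_reason` set to `tool_calls`.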
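Wrap the merging of the tool_calls fragments in a function, merge_tool_calls. Each fragment carries an index that identifies the tool call it belongs to. When the model makes several calls in parallel, they arrive with index 0, 1, and so on, so the merge must key on that index rather than on the fragment's position within a chunk.
def merge_tool_calls(tool_calls: List[ChoiceDeltaToolCall], delta_tool_calls: List[ChoiceDeltaToolCall]):
    for delta_tool_call in delta_tool_calls:
        # Use the index carried by the fragment, not its position in this chunk:
        # parallel tool calls usually arrive in separate chunks with one element each
        index = delta_tool_call.index
        if len(tool_calls) <= index:
            # First fragment of a new tool call: it carries the id and the function name
            if delta_tool_call.function.arguments is None:
                delta_tool_call.function.arguments = ''
            tool_calls.append(delta_tool_call)
            continue
        # Later fragments only carry another piece of the JSON arguments
        ref = tool_calls[index]
        ref.function.arguments += delta_tool_call.function.arguments or ''
Define the function main, which implements the following flow:
- Wait for user input.
- Send the query to the LLM. Besides the user's query, the tools definitions must be included.
- If the LLM's response contains tool_calls, call the corresponding functions and send the results back to the LLM. Note that a response may contain several tool calls, for example when the user asks "What's the weather like in Beijing and Guangzhou?".
- Maintain the conversation context in the variable messages. The user's queries, the LLM's responses, and the function results all have to be recorded in messages.
- Support both streaming and non-streaming output.
def main():
    MAX_MESSAGES_NUM = 20
    messages: List[ChatCompletionMessageParam] = list()
    needInput = True
    while True:
        # Keep only the last 20 messages as context
        if len(messages) > MAX_MESSAGES_NUM:
            messages = messages[-MAX_MESSAGES_NUM:]
            # Drop leading messages until the context starts with a system or user message,
            # so it never begins with an orphaned tool result
            while len(messages) > 0:
                role = messages[0]['role']
                if role == 'system' or role == 'user':
                    break
                messages = messages[1:]
        # Wait for user input
        if needInput:
            query = input("\n>>>> Enter your question: ").strip()
            if query == "":
                continue
            messages.append(ChatCompletionUserMessageParam(
                role="user", content=query))
        # Query the LLM (besides the user's query, the tools definitions must be sent)
        chat_completion = client.chat.completions.create(
            messages=messages,
            tools=tools,
            model=model,
            stream=stream
        )
        tool_calls = None
        content = None
        if stream:
            # Streaming: merge tool_call fragments and print text as it arrives
            for chunk in chat_completion:
                if len(chunk.choices) == 0:
                    continue
                delta = chunk.choices[0].delta
                if isinstance(delta.tool_calls, list):
                    if tool_calls is None:
                        tool_calls = []
                    merge_tool_calls(tool_calls, delta.tool_calls)
                elif isinstance(delta.content, str):
                    if content is None:
                        content = ""
                    content += delta.content
                    print(delta.content, end='', flush=True)
        else:
            # Non-streaming: the complete message is available at once
            tool_calls = chat_completion.choices[0].message.tool_calls
            content = chat_completion.choices[0].message.content
        if isinstance(tool_calls, list) and len(tool_calls) > 0:
            # The response requests tool calls: run them and send the results back
            # to the LLM in the next iteration without asking the user for input
            needInput = False
            messages.append(ChatCompletionAssistantMessageParam(
                role="assistant",
                content='',
                # Plain dicts in the request format (drops stream-only fields such as index)
                tool_calls=[{"id": tc.id, "type": "function",
                             "function": {"name": tc.function.name,
                                          "arguments": tc.function.arguments}}
                            for tc in tool_calls]))
            # Note: the response may contain several tool calls
            for tool_call in tool_calls:
                result = invoke_tool(tool_call)
                messages.append(result)
        else:
            needInput = True
            if isinstance(content, str) and len(content) > 0:
                if not stream:
                    print(content)
                messages.append(ChatCompletionAssistantMessageParam(
                    role="assistant", content=content))


if __name__ == "__main__":
    main()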
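The reason for keying the merge on `.index` shows up with parallel tool calls. OpenAI typically streams each call in chunks that hold a single delta. The position within the chunk's list is therefore always 0, and only `.index` distinguishes the calls. The check below is self-contained under these assumptions:
- The merge logic is repeated inline.
- `make_delta` is a hypothetical helper that builds `SimpleNamespace` stand-ins for `ChoiceDeltaToolCall`.
- The call ids are made up.

```python
from types import SimpleNamespace


def merge_tool_calls(tool_calls, delta_tool_calls):
    # Index-keyed merge, repeated here so the example runs on its own
    for delta in delta_tool_calls:
        if len(tool_calls) <= delta.index:
            if delta.function.arguments is None:
                delta.function.arguments = ''
            tool_calls.append(delta)
            continue
        tool_calls[delta.index].function.arguments += delta.function.arguments or ''


def make_delta(index, arguments, call_id=None, name=None):
    """Hypothetical stand-in for a ChoiceDeltaToolCall."""
    return SimpleNamespace(index=index, id=call_id,
                           function=SimpleNamespace(name=name, arguments=arguments))


# Two parallel calls, e.g. for "What's the weather like in Beijing and Guangzhou?"
stream_deltas = [
    [make_delta(0, '', call_id='call_a', name='get_current_weather')],
    [make_delta(0, '{"location": "Beijing"}')],
    [make_delta(1, '', call_id='call_b', name='get_current_weather')],
    [make_delta(1, '{"location": "Guangzhou"}')],
]
merged = []
for deltas in stream_deltas:
    merge_tool_calls(merged, deltas)
for call in merged:
    print(call.id, call.function.name, call.function.arguments)
# call_a get_current_weather {"location": "Beijing"}
# call_b get_current_weather {"location": "Guangzhou"}
```

Suppose the merge used `enumerate` over each chunk's list instead. Every fragment would then land on the first call, and the Guangzhou request would be lost.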
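Run it to test; the result is shown in the GIF at the beginning of this article. The Agent can now do math calculations, look up the weather, and search the web.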