1. 此处先不使用继承,单独用一个类实现一个 Agent 原型
/**
 * Stand-alone ReAct-style agent prototype (deliberately written without inheritance).
 *
 * <p>Each call to {@link #run(String)} builds a text prompt that lists the available tools,
 * sends it to the chat model, and looks for an {@code Action:} / {@code Action Input:} pair in
 * the reply. When a tool call is found the tool is executed and its observation is appended to
 * the prompt for the next round; otherwise the reply is returned as the final answer.
 *
 * <p>Every step is recorded in the agent's memory. Note that this class only writes to the
 * memory; it is not read back when the prompt is built.
 */
public class MyReActAgent {
    private static final Logger logger = LoggerFactory.getLogger(MyReActAgent.class);

    /**
     * Tool name following an {@code Action:} marker. Word characters, optionally joined by
     * '-' or '.', so names such as {@code get-weather} are not truncated.
     */
    private static final Pattern ACTION_PATTERN =
            Pattern.compile("Action:\\s*(\\w+(?:[-.]\\w+)*)");

    /** The {@code Action Input:} marker plus any whitespace (including newlines) after it. */
    private static final Pattern ACTION_INPUT_MARKER = Pattern.compile("Action Input:\\s*");

    /**
     * Lenient single-line value matcher used for scalars and as a fallback when the
     * balanced scan cannot find a well-formed object, array or string.
     */
    private static final Pattern LENIENT_VALUE = Pattern.compile(
            "\\{.*?\\}|\\[.*?\\]|\".*?\"|-?\\d+(?:\\.\\d+)?|true|false|null");

    private final ChatModel llm;
    private final ToolExecutor toolExecutor;
    private final List<Tool> tools;
    private final InMemoryMemory memory;
    private int maxIterations = 5;
    private String systemPrompt;

    /**
     * Creates an agent without a system prompt.
     *
     * @param llm   chat model used for every reasoning step
     * @param tools tools the model may call; also used to build the prompt
     */
    public MyReActAgent(ChatModel llm, List<Tool> tools) {
        this.llm = llm;
        this.tools = tools;
        this.toolExecutor = new ToolExecutor(tools);
        // Initialise the memory; the argument is presumably its capacity (confirm in InMemoryMemory).
        this.memory = new InMemoryMemory(10);
    }

    /**
     * Creates an agent whose prompt starts with the given system prompt.
     *
     * @param llm          chat model used for every reasoning step
     * @param tools        tools the model may call
     * @param systemPrompt text placed at the top of the prompt; ignored when null or empty
     */
    public MyReActAgent(ChatModel llm, List<Tool> tools, String systemPrompt) {
        this(llm, tools);
        this.systemPrompt = systemPrompt;
    }

    /**
     * Runs the reason/act loop for one user input.
     *
     * @param input the user's question
     * @return the model reply that contained no parseable tool call, or a fallback sentence
     *         when {@code maxIterations} rounds all ended in tool calls
     */
    public String run(String input) {
        logger.info("Agent started with input: {}", input);
        memory.add("user", input);

        StringBuilder fullPrompt = new StringBuilder();
        if (systemPrompt != null && !systemPrompt.isEmpty()) {
            fullPrompt.append(systemPrompt).append("\n\n");
        }
        buildToolPrompt(fullPrompt);
        fullPrompt.append("Begin!\n")
                .append("Question: ").append(input).append("\n");

        for (int i = 0; i < maxIterations; i++) {
            logger.debug("Starting iteration {}/{}", i + 1, maxIterations);
            String response = llm.chat(fullPrompt.toString());
            logger.debug("LLM Response: {}", response);

            // A reply without a parseable tool call is treated as the final answer.
            ToolCall toolCall = hasToolCall(response) ? extractToolCall(response) : null;
            if (toolCall == null) {
                return finish(response);
            }

            String toolName = toolCall.getFunction().getName();
            String args = toolCall.getFunction().getArguments();
            logger.info("Tool selected: {} with args: {}", toolName, args);
            String observation = toolExecutor.execute(toolCall);
            logger.info("Tool executed. Observation: {}", observation);

            // Record the step in memory.
            memory.add("assistant", "Action: " + toolName + ", Input: " + args);
            memory.add("system", "Observation: " + observation);

            // Extend the prompt so the next round sees the action and its result.
            fullPrompt.append("Thought: I need to use a tool.\n")
                    .append("Action: ").append(toolName).append("\n")
                    .append("Action Input: ").append(args).append("\n")
                    .append("Observation: ").append(observation).append("\n");
        }

        String fallback = "I couldn't find a solution within " + maxIterations + " steps.";
        logger.warn("Max iterations reached. Returning fallback.");
        memory.add("assistant", fallback);
        return fallback;
    }

    /** Logs the final answer, stores it in memory and hands it back to the caller. */
    private String finish(String finalAnswer) {
        logger.info("Final answer generated: {}", finalAnswer);
        memory.add("assistant", finalAnswer);
        return finalAnswer;
    }

    /**
     * Checks whether the reply mentions a tool call.
     *
     * @param response model reply, may be null
     * @return true when both the {@code Action:} and {@code Action Input:} markers are present
     */
    private boolean hasToolCall(String response) {
        return response != null && response.contains("Action:") && response.contains("Action Input:");
    }

    /**
     * Extracts the tool name and its input from the model reply.
     *
     * @param response model reply containing the Action markers
     * @return a populated ToolCall, or null when no name or no input value could be read
     */
    private ToolCall extractToolCall(String response) {
        try {
            Matcher actionMatcher = ACTION_PATTERN.matcher(response);
            if (!actionMatcher.find()) {
                return null;
            }
            String arguments = extractActionInput(response);
            if (arguments == null) {
                logger.warn("Action markers found but no parseable Action Input value");
                return null;
            }
            ToolCall toolCall = new ToolCall();
            toolCall.setType("function");
            toolCall.setFunction(new FunctionCall(actionMatcher.group(1), arguments));
            return toolCall;
        } catch (RuntimeException e) {
            logger.error("Error extracting tool call from response: {}", e.getMessage(), e);
            return null;
        }
    }

    /** Returns the value after the first {@code Action Input:} marker that carries a readable value. */
    private static String extractActionInput(String response) {
        Matcher marker = ACTION_INPUT_MARKER.matcher(response);
        while (marker.find()) {
            String value = readValue(response, marker.end());
            if (value != null) {
                return value;
            }
        }
        return null;
    }

    /**
     * Reads one JSON-like value starting exactly at {@code start}.
     * Objects, arrays and strings are scanned structurally so nested and multi-line JSON
     * survive intact; anything else goes through the lenient single-line pattern.
     */
    private static String readValue(String text, int start) {
        if (start >= text.length()) {
            return null;
        }
        char first = text.charAt(start);
        String value = null;
        if (first == '{' || first == '[') {
            value = readBalanced(text, start);
        } else if (first == '"') {
            value = readQuoted(text, start);
        }
        if (value != null) {
            return value;
        }
        Matcher lenient = LENIENT_VALUE.matcher(text).region(start, text.length());
        return lenient.lookingAt() ? lenient.group() : null;
    }

    /** Returns the span from an opening bracket to its matching close, ignoring brackets inside strings. */
    private static String readBalanced(String text, int start) {
        int depth = 0;
        boolean inString = false;
        for (int i = start; i < text.length(); i++) {
            char c = text.charAt(i);
            if (inString) {
                if (c == '\\') {
                    i++; // skip the escaped character
                } else if (c == '"') {
                    inString = false;
                }
            } else if (c == '"') {
                inString = true;
            } else if (c == '{' || c == '[') {
                depth++;
            } else if (c == '}' || c == ']') {
                depth--;
                if (depth == 0) {
                    return text.substring(start, i + 1);
                }
            }
        }
        return null;
    }

    /** Returns a double-quoted string literal starting at {@code start}, honouring backslash escapes. */
    private static String readQuoted(String text, int start) {
        for (int i = start + 1; i < text.length(); i++) {
            char c = text.charAt(i);
            if (c == '\\') {
                i++; // skip the escaped character
            } else if (c == '"') {
                return text.substring(start, i + 1);
            } else if (c == '\n' || c == '\r') {
                return null; // a string literal cannot span lines
            }
        }
        return null;
    }

    /** Appends the tool catalogue and the Thought/Action/Observation format instructions. */
    private void buildToolPrompt(StringBuilder prompt) {
        prompt.append("Answer the following question. You can use tools.\n");
        prompt.append("Available tools:\n");
        for (Tool tool : tools) {
            prompt.append("- ").append(tool.getName())
                    .append(": ").append(tool.getDescription())
                    .append("\n Parameters: ").append(tool.getParametersSchema())
                    .append("\n");
        }
        prompt.append("\nUse the following format:\n")
                .append("Thought: you should always think about what to do\n")
                .append("Action: the action to take, should be one of [")
                .append(String.join(", ", tools.stream().map(Tool::getName).toArray(String[]::new)))
                .append("]\n")
                .append("Action Input: the input to the action, can be JSON object, array, string, number, boolean or null\n")
                .append("Observation: the result of the action\n")
                .append("... (Thought/Action/Observation can repeat)\n")
                .append("Thought: I now know the final answer\n")
                .append("Final Answer: the final answer to the original input question\n\n");
    }

    /** Returns the memory holding the recorded steps (useful for debugging or context enrichment). */
    public InMemoryMemory getMemory() {
        return memory;
    }

    /**
     * Sets the maximum number of model rounds per {@link #run(String)} call.
     *
     * @throws IllegalArgumentException when {@code maxIterations} is not positive
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive");
        }
        this.maxIterations = maxIterations;
    }

    /** Sets the system prompt placed at the top of the prompt; null or empty disables it. */
    public void setSystemPrompt(String systemPrompt) {
        this.systemPrompt = systemPrompt;
    }

    /** Returns the current system prompt, or null when none is set. */
    public String getSystemPrompt() {
        return systemPrompt;
    }
}
2. ReAct Agent 需要支持记忆、工具调用等功能
创建工具类:每个工具需要有工具名称、工具描述,以及工具执行的具体逻辑。
/**
 * Contract for a capability the agent can invoke: a name, a human-readable description,
 * a parameter schema, and the logic that performs the work.
 */
public interface Tool {

    /** Returns the name under which this tool is registered and looked up. */
    String getName();

    /** Returns a description of what the tool does; it is shown to the model in the prompt. */
    String getDescription();

    /** Returns the tool's parameters described as a JSON Schema document. */
    String getParametersSchema();

    /**
     * Runs the tool.
     *
     * @param arguments the call arguments as a JSON string
     * @return the tool's result as text
     */
    String execute(String arguments);
}
/**
 * Looks up a {@link Tool} by the name carried in a {@link ToolCall} and runs it.
 * Lookup failures are reported as a small JSON error object instead of an exception.
 */
public class ToolExecutor {
    private final Map<String, Tool> toolRegistry;

    /**
     * Registers the given tools under their names.
     * When two tools share a name, the later one replaces the earlier one.
     *
     * @param tools tools to make available for execution
     */
    public ToolExecutor(List<Tool> tools) {
        this.toolRegistry = new HashMap<>();
        for (Tool tool : tools) {
            toolRegistry.put(tool.getName(), tool);
        }
    }

    /**
     * Executes the tool named by the call with the call's argument string.
     *
     * @param toolCall the call to execute
     * @return the tool's result, or a JSON object with an {@code error} field when the call
     *         has no function part or names an unknown tool
     */
    public String execute(ToolCall toolCall) {
        if (toolCall == null || toolCall.getFunction() == null) {
            // Report a malformed call the same way as an unknown tool, rather than with an NPE.
            return "{\"error\": \"Tool call is missing its function\"}";
        }
        String toolName = toolCall.getFunction().getName();
        String arguments = toolCall.getFunction().getArguments();
        Tool tool = toolRegistry.get(toolName);
        if (tool == null) {
            // The name comes from outside this class, so escape it to keep the reply valid JSON.
            return "{\"error\": \"Tool not found: " + escapeJson(String.valueOf(toolName)) + "\"}";
        }
        return tool.execute(arguments);
    }

    /** Escapes quotes, backslashes and control characters for use inside a JSON string literal. */
    private static String escapeJson(String text) {
        StringBuilder out = new StringBuilder(text.length() + 8);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    out.append("\\\"");
                    break;
                case '\\':
                    out.append("\\\\");
                    break;
                case '\n':
                    out.append("\\n");
                    break;
                case '\r':
                    out.append("\\r");
                    break;
                case '\t':
                    out.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        out.append(String.format("\\u%04x", (int) c));
                    } else {
                        out.append(c);
                    }
            }
        }
        return out.toString();
    }
}
3. 示例:一个具体的工具实现
/**
 * Example {@link Tool} that returns canned weather data for a requested city.
 * The reply is always a JSON object: either the weather fields or a single {@code error} field.
 */
public class WeatherTool implements Tool {
    private static final ObjectMapper mapper = new ObjectMapper(); // shared instance

    @Override
    public String getName() {
        return "get_weather";
    }

    @Override
    public String getDescription() {
        return "获取指定城市的天气信息";
    }

    @Override
    public String getParametersSchema() {
        return """
                {
                "type": "object",
                "properties": {
                "location": {
                "type": "string",
                "description": "城市名称"
                }
                },
                "required": ["location"]
                }
                """;
    }

    /**
     * Parses the arguments and returns simulated weather for the requested location.
     *
     * @param arguments JSON object text with a {@code location} field
     * @return weather JSON, or an error JSON when the field is absent/null or the text is not JSON
     */
    @Override
    public String execute(String arguments) {
        try {
            JsonNode node = mapper.readTree(arguments);
            // A JSON null location is treated like a missing one instead of the city "null".
            if (node == null || !node.hasNonNull("location")) {
                return "{\"error\": \"Missing required field: location\"}";
            }
            String location = node.get("location").asText();
            // Simulated weather data. The location is serialised by Jackson so quotes,
            // backslashes and control characters cannot break the reply.
            return String.format(
                    "{\"location\": %s, \"temperature\": \"28°C\", \"condition\": \"Sunny\"}",
                    mapper.writeValueAsString(location));
        } catch (Exception e) {
            return errorJson("Invalid JSON: " + e.getMessage());
        }
    }

    /** Wraps a message in an {@code {"error": ...}} object with the message properly escaped. */
    private static String errorJson(String message) {
        try {
            return "{\"error\": " + mapper.writeValueAsString(message) + "}";
        } catch (Exception serializationFailure) {
            // Serialising a plain String is not expected to fail; keep the reply well-formed regardless.
            return "{\"error\": \"Invalid JSON\"}";
        }
    }
}
4. 工具调用需要的参数封装
/**
 * Envelope describing one tool invocation: an identifier, a type tag and the function
 * (name plus arguments) to run. Fields are mapped to JSON properties of the same names.
 */
public class ToolCall {

    /** Identifier of this call. */
    @JsonProperty("id")
    private String id;

    /** Kind of call; the agent in this file sets it to "function". */
    @JsonProperty("type")
    private String type;

    /** The function name and arguments to execute. */
    @JsonProperty("function")
    private FunctionCall function;

    /** Returns the call identifier. */
    public String getId() {
        return this.id;
    }

    /** Sets the call identifier. */
    public void setId(String id) {
        this.id = id;
    }

    /** Returns the call type. */
    public String getType() {
        return this.type;
    }

    /** Sets the call type. */
    public void setType(String type) {
        this.type = type;
    }

    /** Returns the function part of the call. */
    public FunctionCall getFunction() {
        return this.function;
    }

    /** Sets the function part of the call. */
    public void setFunction(FunctionCall function) {
        this.function = function;
    }
}
/**
 * Name and argument payload of a function (tool) to invoke.
 * Fields are mapped to JSON properties of the same names.
 */
public class FunctionCall {

    /** Name of the function to call. */
    @JsonProperty("name")
    private String name;

    /** Arguments for the call, kept as a JSON string. */
    @JsonProperty("arguments")
    private String arguments;

    /** Creates an empty instance; both fields start as null. */
    public FunctionCall() {
    }

    /**
     * Creates a fully populated instance.
     *
     * @param name      function name
     * @param arguments arguments as a JSON string
     */
    public FunctionCall(String name, String arguments) {
        this.name = name;
        this.arguments = arguments;
    }

    /** Returns the function name. */
    public String getName() {
        return this.name;
    }

    /** Sets the function name. */
    public void setName(String name) {
        this.name = name;
    }

    /** Returns the arguments as a JSON string. */
    public String getArguments() {
        return this.arguments;
    }

    /** Sets the arguments as a JSON string. */
    public void setArguments(String arguments) {
        this.arguments = arguments;
    }
}
5. 调用agent
// Build the agent from a chat model and the tools it may call; this two-argument
// constructor leaves the system prompt unset and keeps the default iteration limit.
reActAgent = new MyReActAgent(openAiChatModel, availableTools);
// Run one user message through the agent loop and keep the returned answer text.
response = reActAgent.run(userMessageText);
625

被折叠的 条评论
为什么被折叠?



