如何使用 Java 编写一个 Agent

1. 此处先不使用继承，用一个独立的类编写 Agent 原型。

/**
 * A minimal, self-contained ReAct (Reason + Act) agent prototype.
 *
 * <p>{@link #run(String)} builds a text prompt that lists the available tools and the
 * Thought/Action/Action Input/Observation format, then repeatedly asks the model for the next
 * step. When the model emits a parseable {@code Action:} / {@code Action Input:} pair, the matching
 * tool is executed and its result is appended to the prompt as an {@code Observation:}. Otherwise
 * the response is treated as the final answer.
 *
 * <p>Every exchange is also recorded in an {@link InMemoryMemory}, exposed via {@link #getMemory()}.
 * That memory is not read back when building the prompt.
 *
 * <p>No synchronization is performed here. NOTE(review): presumably one instance should not run
 * concurrently; confirm against InMemoryMemory's thread-safety.
 */
public class MyReActAgent {
    private static final Logger logger = LoggerFactory.getLogger(MyReActAgent.class);

    /** Marker the prompt asks the model to place before its final answer. */
    private static final String FINAL_ANSWER_MARKER = "Final Answer:";

    /** Tool name on an "Action:" line (word characters and hyphens). */
    private static final Pattern ACTION_PATTERN = Pattern.compile("Action:\\s*([\\w-]+)");

    /** "Action Input:" label plus any whitespace (including newlines) before the value. */
    private static final Pattern ACTION_INPUT_PATTERN = Pattern.compile("Action Input:\\s*");

    private final ChatModel llm;
    private final ToolExecutor toolExecutor;
    private final List<Tool> tools;
    private final InMemoryMemory memory;
    private int maxIterations = 5;
    private String systemPrompt = null;

    /**
     * Creates an agent without a system prompt.
     *
     * @param llm   chat model queried once per reasoning step; must not be null
     * @param tools tools the model may call; copied defensively
     * @throws NullPointerException if {@code llm}, {@code tools} or any element of {@code tools} is null
     */
    public MyReActAgent(ChatModel llm, List<Tool> tools) {
        this.llm = java.util.Objects.requireNonNull(llm, "llm");
        // Snapshot the list so the prompt (built from this.tools) and the executor's registry
        // cannot drift apart if the caller mutates its list later.
        this.tools = List.copyOf(java.util.Objects.requireNonNull(tools, "tools"));
        this.toolExecutor = new ToolExecutor(this.tools);
        // NOTE(review): 10 presumably caps the number of retained entries; confirm in InMemoryMemory.
        this.memory = new InMemoryMemory(10);
    }

    /**
     * Creates an agent whose prompt starts with {@code systemPrompt}.
     *
     * @param llm          chat model queried once per reasoning step; must not be null
     * @param tools        tools the model may call; copied defensively
     * @param systemPrompt text placed at the top of every prompt; null or empty means none
     */
    public MyReActAgent(ChatModel llm, List<Tool> tools, String systemPrompt) {
        this(llm, tools);
        this.systemPrompt = systemPrompt;
    }

    /**
     * Runs the ReAct loop for one user question.
     *
     * @param input the user's question; recorded in memory and embedded in the prompt
     * @return the text after the last {@code Final Answer:} marker when the model uses it, the raw
     *     model response when it does not, or a fallback message when every one of the
     *     {@code maxIterations} model calls requested a tool
     */
    public String run(String input) {
        logger.info("Agent started with input: {}", input);
        memory.add("user", input);

        StringBuilder fullPrompt = new StringBuilder();
        if (systemPrompt != null && !systemPrompt.isEmpty()) {
            fullPrompt.append(systemPrompt).append("\n\n");
        }
        buildToolPrompt(fullPrompt);

        fullPrompt.append("Begin!\n")
                  .append("Question: ").append(input).append("\n");

        for (int i = 0; i < maxIterations; i++) {
            logger.debug("Starting iteration {}/{}", i + 1, maxIterations);
            String response = llm.chat(fullPrompt.toString());
            logger.debug("LLM Response: {}", response);

            // A response without a parseable Action / Action Input pair ends the loop.
            ToolCall toolCall = hasToolCall(response) ? extractToolCall(response) : null;
            if (toolCall == null) {
                String finalAnswer = extractFinalAnswer(response);
                logger.info("Final answer generated: {}", finalAnswer);
                memory.add("assistant", finalAnswer);
                return finalAnswer;
            }

            String toolName = toolCall.getFunction().getName();
            String args = toolCall.getFunction().getArguments();
            logger.info("Tool selected: {} with args: {}", toolName, args);

            String observation = toolExecutor.execute(toolCall);
            logger.info("Tool executed. Observation: {}", observation);

            // Record the step in memory.
            memory.add("assistant", "Action: " + toolName + ", Input: " + args);
            memory.add("system", "Observation: " + observation);

            // Extend the scratchpad so the next model call sees the action and its real result.
            // NOTE(review): the Thought line is fixed text, not the model's own reasoning.
            fullPrompt.append("Thought: I need to use a tool.\n")
                      .append("Action: ").append(toolName).append("\n")
                      .append("Action Input: ").append(args).append("\n")
                      .append("Observation: ").append(observation).append("\n");
        }

        String fallback = "I couldn't find a solution within " + maxIterations + " steps.";
        logger.warn("Max iterations reached. Returning fallback.");
        memory.add("assistant", fallback);
        return fallback;
    }

    /**
     * Cheap pre-check for a tool call in the model response.
     *
     * @param response model response, may be null
     * @return true if both an "Action:" and an "Action Input:" label are present
     */
    private boolean hasToolCall(String response) {
        return response != null && response.contains("Action:") && response.contains("Action Input:");
    }

    /**
     * Parses the first "Action:" line and the "Action Input:" that follows it.
     *
     * @param response model response, may be null
     * @return a function-type ToolCall, or null when no complete action/input pair can be parsed
     */
    private ToolCall extractToolCall(String response) {
        if (response == null) {
            return null;
        }
        try {
            Matcher actionMatcher = ACTION_PATTERN.matcher(response);
            if (!actionMatcher.find()) {
                return null;
            }
            // Only accept an input that comes after this action, so the two are paired.
            Matcher inputMatcher = ACTION_INPUT_PATTERN.matcher(response);
            if (!inputMatcher.find(actionMatcher.end())) {
                return null;
            }
            String arguments = extractActionInput(response, inputMatcher.end());
            if (arguments == null) {
                return null;
            }

            ToolCall toolCall = new ToolCall();
            toolCall.setType("function");
            toolCall.setFunction(new FunctionCall(actionMatcher.group(1), arguments));
            return toolCall;
        } catch (RuntimeException e) {
            logger.error("Error extracting tool call from response: {}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * Extracts the Action Input value starting at {@code start}.
     *
     * <p>JSON objects, arrays and string literals are read up to their balanced closing character,
     * so nested and multi-line JSON is kept intact. Any other value is the rest of the line.
     *
     * @return the raw input text, or null if there is none or a JSON value is unterminated
     */
    private static String extractActionInput(String text, int start) {
        if (start >= text.length()) {
            return null;
        }
        char first = text.charAt(start);
        if (first == '{' || first == '[' || first == '"') {
            return scanJsonValue(text, start);
        }
        int lineEnd = text.indexOf('\n', start);
        String line = (lineEnd < 0 ? text.substring(start) : text.substring(start, lineEnd)).strip();
        // An empty "Action Input:" followed directly by "Observation:" has no usable input.
        return line.startsWith("Observation:") ? null : line;
    }

    /**
     * Returns the JSON object, array or string literal beginning at {@code start}.
     *
     * <p>Tracks string literals (with backslash escapes) so brackets inside strings are ignored.
     *
     * @return the value text, or null if it is not closed before the end of {@code text}
     */
    private static String scanJsonValue(String text, int start) {
        int depth = 0;
        boolean inString = false;
        boolean escaped = false;
        for (int i = start; i < text.length(); i++) {
            char c = text.charAt(i);
            if (inString) {
                if (escaped) {
                    escaped = false;
                } else if (c == '\\') {
                    escaped = true;
                } else if (c == '"') {
                    inString = false;
                    if (depth == 0) {
                        return text.substring(start, i + 1); // top-level string literal
                    }
                }
            } else if (c == '"') {
                inString = true;
            } else if (c == '{' || c == '[') {
                depth++;
            } else if (c == '}' || c == ']') {
                depth--;
                if (depth == 0) {
                    return text.substring(start, i + 1);
                }
            }
        }
        return null;
    }

    /**
     * Strips the ReAct preamble from a final response.
     *
     * @param response model response, may be null
     * @return the trimmed text after the last "Final Answer:" marker, or {@code response} unchanged
     *     when the marker is absent
     */
    private static String extractFinalAnswer(String response) {
        if (response == null) {
            return null;
        }
        int idx = response.lastIndexOf(FINAL_ANSWER_MARKER);
        return idx < 0 ? response : response.substring(idx + FINAL_ANSWER_MARKER.length()).strip();
    }

    /** Appends the tool catalogue and the ReAct output-format instructions to {@code prompt}. */
    private void buildToolPrompt(StringBuilder prompt) {
        prompt.append("Answer the following question. You can use tools.\n");
        prompt.append("Available tools:\n");

        for (Tool tool : tools) {
            prompt.append("- ").append(tool.getName())
                  .append(": ").append(tool.getDescription())
                  .append("\n  Parameters: ").append(tool.getParametersSchema())
                  .append("\n");
        }

        prompt.append("\nUse the following format:\n")
              .append("Thought: you should always think about what to do\n")
              .append("Action: the action to take, should be one of [")
              .append(String.join(", ", tools.stream().map(Tool::getName).toArray(String[]::new)))
              .append("]\n")
              .append("Action Input: the input to the action, can be JSON object, array, string, number, boolean or null\n")
              .append("Observation: the result of the action\n")
              .append("... (Thought/Action/Observation can repeat)\n")
              .append("Thought: I now know the final answer\n")
              .append("Final Answer: the final answer to the original input question\n\n");
    }

    /** Returns the conversation memory (useful for debugging or context enrichment). */
    public InMemoryMemory getMemory() {
        return memory;
    }

    /**
     * Sets the maximum number of model calls per {@link #run(String)}.
     *
     * @throws IllegalArgumentException if {@code maxIterations} is not positive
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive");
        }
        this.maxIterations = maxIterations;
    }

    /** Sets the system prompt; null or empty disables it. */
    public void setSystemPrompt(String systemPrompt) {
        this.systemPrompt = systemPrompt;
    }

    /** Returns the current system prompt, possibly null. */
    public String getSystemPrompt() {
        return systemPrompt;
    }
}

2. ReAct Agent 需要支持记忆、工具调用等功能。

创建工具接口：每个工具需要提供名称、描述以及具体的执行逻辑。

/**
 * A capability that the agent can invoke by name.
 *
 * <p>The three descriptive methods tell the model what the tool is and how to call it;
 * {@link #execute(String)} does the actual work.
 */
public interface Tool {

    /**
     * Performs the tool's work.
     *
     * @param arguments the call arguments, encoded as a JSON string
     * @return the result as a string
     */
    String execute(String arguments);

    /** Returns the identifier the model uses to select this tool. */
    String getName();

    /** Returns a human-readable summary of what the tool does. */
    String getDescription();

    /** Returns a JSON Schema document describing the arguments {@link #execute(String)} accepts. */
    String getParametersSchema();
}

/**
 * Dispatches a {@link ToolCall} to the registered {@link Tool} of the same name.
 *
 * <p>The following are returned as a JSON error object ({@code {"error": "..."}}) instead of being
 * thrown, so an agent loop can show them to the model as an observation:
 * <ul>
 *   <li>an unknown tool name;</li>
 *   <li>a malformed call;</li>
 *   <li>an unchecked exception thrown by a tool.</li>
 * </ul>
 */
public class ToolExecutor {
    /** Tools keyed by {@link Tool#getName()}; when two tools share a name, the later one wins. */
    private final Map<String, Tool> toolRegistry;

    /**
     * Registers the given tools.
     *
     * @param tools tools to register; a null list or null element causes a NullPointerException
     */
    public ToolExecutor(List<Tool> tools) {
        this.toolRegistry = new HashMap<>();
        for (Tool tool : tools) {
            toolRegistry.put(tool.getName(), tool);
        }
    }

    /**
     * Executes the tool named in {@code toolCall}.
     *
     * @param toolCall the call to run; may be null, which yields an error result
     * @return the tool's own result, or a JSON error object if the call is malformed, the tool is
     *     unknown, or the tool throws a RuntimeException
     */
    public String execute(ToolCall toolCall) {
        if (toolCall == null || toolCall.getFunction() == null) {
            return errorJson("Malformed tool call: missing function");
        }
        String toolName = toolCall.getFunction().getName();
        String arguments = toolCall.getFunction().getArguments();

        Tool tool = toolRegistry.get(toolName);
        if (tool == null) {
            return errorJson("Tool not found: " + toolName);
        }

        try {
            return tool.execute(arguments);
        } catch (RuntimeException e) {
            // Report the failure as an observation instead of aborting the whole agent run.
            return errorJson("Tool '" + toolName + "' failed: " + e);
        }
    }

    /** Builds {@code {"error": "<message>"}}, escaping the message so the result is valid JSON. */
    private static String errorJson(String message) {
        StringBuilder json = new StringBuilder(message.length() + 16).append("{\"error\": \"");
        for (int i = 0; i < message.length(); i++) {
            char c = message.charAt(i);
            switch (c) {
                case '"' -> json.append("\\\"");
                case '\\' -> json.append("\\\\");
                case '\n' -> json.append("\\n");
                case '\r' -> json.append("\\r");
                case '\t' -> json.append("\\t");
                default -> {
                    if (c < 0x20) {
                        json.append(String.format("\\u%04x", (int) c));
                    } else {
                        json.append(c);
                    }
                }
            }
        }
        return json.append("\"}").toString();
    }
}

3. 示例：一个工具的具体实现

/**
 * Example {@link Tool} that returns canned (mock) weather data for a city.
 *
 * <p>The arguments must be a JSON object with a non-blank string {@code location}. Results and
 * errors are JSON object strings built with Jackson, so values containing quotes or control
 * characters are escaped correctly.
 */
public class WeatherTool implements Tool {
    /** Shared mapper instance; Jackson documents ObjectMapper as thread-safe once configured. */
    private static final ObjectMapper mapper = new ObjectMapper();

    @Override
    public String getName() {
        return "get_weather";
    }

    @Override
    public String getDescription() {
        return "获取指定城市的天气信息";
    }

    @Override
    public String getParametersSchema() {
        return """
        {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "城市名称"
            }
          },
          "required": ["location"]
        }
        """;
    }

    /**
     * Returns mock weather data for the requested location.
     *
     * @param arguments JSON object such as {@code {"location": "Beijing"}}
     * @return a JSON object with {@code location}, {@code temperature} and {@code condition}, or a
     *     JSON object with a single {@code error} field when the input is invalid
     */
    @Override
    public String execute(String arguments) {
        JsonNode node;
        try {
            node = mapper.readTree(arguments);
        } catch (Exception e) {
            // Any parse failure is reported back as an error object rather than thrown.
            return error("Invalid JSON: " + e.getMessage());
        }

        // get() yields Java null for absent fields and for non-object nodes.
        JsonNode location = (node == null) ? null : node.get("location");
        if (location == null) {
            return error("Missing required field: location");
        }
        // Rejects explicit null, numbers, objects and blank strings, per the schema's "type": "string".
        if (!location.isTextual() || location.asText().isBlank()) {
            return error("Field 'location' must be a non-empty string");
        }

        // Mock weather data; built via Jackson so the location is JSON-escaped.
        return mapper.createObjectNode()
                .put("location", location.asText())
                .put("temperature", "28°C")
                .put("condition", "Sunny")
                .toString();
    }

    /** Builds {@code {"error": message}} with proper JSON escaping. */
    private static String error(String message) {
        return mapper.createObjectNode().put("error", message).toString();
    }
}

4. 工具调用需要的参数封装

/**
 * A single tool invocation requested by the model.
 *
 * <p>The JSON property names ({@code id}, {@code type}, {@code function}) are fixed by the
 * annotations below.
 */
public class ToolCall {
    /** Identifier of the call. */
    @JsonProperty("id")
    private String id;

    /** Kind of call, for example {@code "function"}. */
    @JsonProperty("type")
    private String type;

    /** Name and arguments of the function to run. */
    @JsonProperty("function")
    private FunctionCall function;

    // ---- accessors ----

    public String getId() {
        return id;
    }

    public String getType() {
        return type;
    }

    public FunctionCall getFunction() {
        return function;
    }

    // ---- mutators ----

    public void setId(String id) {
        this.id = id;
    }

    public void setType(String type) {
        this.type = type;
    }

    public void setFunction(FunctionCall function) {
        this.function = function;
    }
}


/**
 * The function part of a tool call: which tool to run and with what arguments.
 */
public class FunctionCall {
    /** Name of the tool to invoke. */
    @JsonProperty("name")
    private String name;

    /** Tool arguments, kept as a raw JSON string. */
    @JsonProperty("arguments")
    private String arguments;

    /** Creates an empty instance (presumably for Jackson); both fields start out null. */
    public FunctionCall() {
        this(null, null);
    }

    /**
     * Creates a fully populated function call.
     *
     * @param name      tool name
     * @param arguments arguments as a JSON string
     */
    public FunctionCall(String name, String arguments) {
        this.name = name;
        this.arguments = arguments;
    }

    // ---- accessors ----

    public String getName() {
        return name;
    }

    public String getArguments() {
        return arguments;
    }

    // ---- mutators ----

    public void setName(String name) {
        this.name = name;
    }

    public void setArguments(String arguments) {
        this.arguments = arguments;
    }
}

5. 调用 Agent

 // Build the agent from a ChatModel (presumably OpenAI-backed, judging by the variable name) and the tools it may call.
 // NOTE(review): reActAgent, response, openAiChatModel, availableTools and userMessageText are assumed to be
 // declared elsewhere (e.g. as MyReActAgent / String); their declarations are not shown in this snippet.
 reActAgent = new MyReActAgent(openAiChatModel, availableTools);
// Run the ReAct loop; returns the agent's answer, or a fallback message if maxIterations is exhausted.
response = reActAgent.run(userMessageText);
5.1 后续可以改进工具的定义方式，或者使用 Spring AI 自带的 @Tool 注解来注册工具。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值