-
引入依赖
<?xml version="1.0" encoding="UTF-8"?>
<!-- pom.xml: Maven build descriptor for the spring-ai-demo project -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.4</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>

    <!-- Generated by https://start.springboot.io -->
    <!-- High-quality Chinese documentation for the spring/boot/data/security/cloud frameworks => https://springdoc.cn -->
    <groupId>com.example</groupId>
    <artifactId>spring-ai-demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>spring-ai-demo</name>
    <description>spring-ai-demo</description>

    <properties>
        <java.version>17</java.version>
        <spring-ai.version>0.8.1</spring-ai.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Version is managed by the spring-ai-bom imported in dependencyManagement below -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

    <repositories>
        <repository>
            <id>spring-milestones</id>
            <name>Spring Milestones</name>
            <url>https://repo.spring.io/milestone</url>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
</project>
-
配置yml
# Spring AI / OpenAI client configuration.
spring:
  ai:
    openai:
      # OpenAI API key. (Author's note: if you have no account, a key can be bought on Taobao for about 2 yuan.)
      api-key:
      # Base URL of the API endpoint. (Author's note: the Taobao seller will give you a relay URL to use here.)
      base-url:
    retry:
      # Maximum number of attempts for a failed request.
      max-attempts: 3
-
Controller
@RestController public class IndexController { // key:sessionId value:会话聊天的上下文 private static final Map<String, List<Message>> messageMap = new HashMap<>(); @Resource private OpenAiChatClient chatClient; @GetMapping public String index(HttpSession session, @RequestParam(name = "message", defaultValue = "使用Java,写一个冒泡算法") String message) { // 检查是否已有会话 List<Message> messageList = messageMap.get(session.getId()); if (messageList == null) { messageList = new ArrayList<>(); messageMap.put(session.getId(), messageList); } // 将用户消息,加入上下文 messageList.add(messageList.size(), new UserMessage(message)); // 发送消息时,传递的是上下文的所有信息,不单单是你当前发送的一条消息 // 所有最好为list设置容量的限制,不然你的api-key的资源会消耗很快 String result = chatClient.call(new Prompt(messageList)).getResult().getOutput().getContent(); // 将ai消息,加入上下文 messageList.add(messageList.size(), new AssistantMessage(result)); return result; } }
优化代码(也可以使用 Redis 代替 messageMap 来存储会话上下文):
OpenAiService.java/** * 记录上下文聊天信息 */ @Service public class OpenAiService { private static final Map<String, LinkedList<Message>> messageMap = new HashMap<>(); @Resource private OpenAiChatClient chatClient; private void addUserMessage(LinkedList<Message> messageList, String message) { checkMessageCapacity(messageList); messageList.addLast(new UserMessage(message)); } private void addAssistantMessage(LinkedList<Message> messageList, String message) { checkMessageCapacity(messageList); messageList.add(messageList.size(), new AssistantMessage(message)); } public String chat(String sessionId, String message) { LinkedList<Message> messageList = messageMap.get(sessionId); if (messageList == null) { messageList = new LinkedList<>(); messageMap.put(sessionId, messageList); } addUserMessage(messageList, message); String result = chatClient.call(new Prompt(messageList)).getResult().getOutput().getContent(); addAssistantMessage(messageList, result); return result; } public void checkMessageCapacity(LinkedList<Message> messages) { if (messages.size() >= 10) { messages.removeFirst(); } } }
IndexController.java
/**
 * HTTP entry point of the chat demo; delegates to {@code OpenAiService} and makes sure
 * that requests belonging to the same session are processed one at a time.
 */
@RestController
public class IndexController {

    // One private lock object per session id. A dedicated object is used instead of
    // session.getId().intern(): an interned String is a JVM-wide shared instance that any
    // other code may also lock on. Entries are never removed here.
    private static final java.util.concurrent.ConcurrentMap<String, Object> SESSION_LOCKS =
            new java.util.concurrent.ConcurrentHashMap<>();

    @Resource
    private OpenAiService openAiService;

    /**
     * Forwards the user's message to the chat service for the current session.
     *
     * @param session current HTTP session; its id identifies the conversation
     * @param message user input for this turn
     * @return the reply returned by the chat service
     */
    @GetMapping
    public String index(HttpSession session,
                        @RequestParam(name = "message", defaultValue = "使用Java,写一个冒泡算法") String message) {
        String sessionId = session.getId();
        Object lock = SESSION_LOCKS.computeIfAbsent(sessionId, id -> new Object());
        // Only one in-flight chat turn per session.
        synchronized (lock) {
            return openAiService.chat(sessionId, message);
        }
    }
}
Spring-AI-上下文记忆
最新推荐文章于 2025-04-13 23:54:41 发布